Added some tests for the Database iterator interface. Updated the post-increment...
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Sun, 12 Oct 2014 22:43:44 +0000 (18:43 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:31:30 +0000 (19:31 -0400)
include/caffe/database.hpp
include/caffe/leveldb_database.hpp
include/caffe/lmdb_database.hpp
src/caffe/leveldb_database.cpp
src/caffe/lmdb_database.cpp
src/caffe/test/test_database.cpp

index c61ab51..08baf2f 100644 (file)
@@ -62,6 +62,20 @@ class Database {
           state_(state) { }
     ~iterator() { }
 
+    iterator(const iterator& other)
+        : parent_(other.parent_),
+          state_(other.state_->clone()) { }
+
+    iterator& operator=(iterator copy) {
+      copy.swap(*this);
+      return *this;
+    }
+
+    void swap(iterator& other) throw() {
+      std::swap(this->parent_, other.parent_);
+      std::swap(this->state_, other.state_);
+    }
+
     bool operator==(const iterator& other) const {
       return parent_->equal(state_, other.state_);
     }
@@ -97,6 +111,7 @@ class Database {
   class DatabaseState {
    public:
     virtual ~DatabaseState() { }
+    virtual shared_ptr<DatabaseState> clone() = 0;
   };
 
   virtual bool equal(shared_ptr<DatabaseState> state1,
index 42f73f9..03bfd38 100644 (file)
@@ -28,10 +28,26 @@ class LeveldbDatabase : public Database {
  protected:
   class LeveldbState : public Database::DatabaseState {
    public:
-    explicit LeveldbState(shared_ptr<leveldb::Iterator> iter)
+    explicit LeveldbState(shared_ptr<leveldb::DB> db,
+        shared_ptr<leveldb::Iterator> iter)
         : Database::DatabaseState(),
+          db_(db),
           iter_(iter) { }
 
+    shared_ptr<DatabaseState> clone() {
+      shared_ptr<leveldb::Iterator> new_iter;
+
+      if (iter_.get()) {
+        new_iter.reset(db_->NewIterator(leveldb::ReadOptions()));
+        CHECK(iter_->Valid());
+        new_iter->Seek(iter_->key());
+        CHECK(new_iter->Valid());
+      }
+
+      return shared_ptr<DatabaseState>(new LeveldbState(db_, new_iter));
+    }
+
+    shared_ptr<leveldb::DB> db_;
     shared_ptr<leveldb::Iterator> iter_;
     KV kv_pair_;
   };
index 4796b4d..7b532f8 100644 (file)
@@ -32,11 +32,35 @@ class LmdbDatabase : public Database {
  protected:
   class LmdbState : public Database::DatabaseState {
    public:
-    explicit LmdbState(MDB_cursor* cursor)
+    explicit LmdbState(MDB_cursor* cursor, MDB_txn* txn, const MDB_dbi* dbi)
         : Database::DatabaseState(),
-          cursor_(cursor) { }
+          cursor_(cursor),
+          txn_(txn),
+          dbi_(dbi) { }
+
+    shared_ptr<DatabaseState> clone() {
+      MDB_cursor* new_cursor;
+
+      if (cursor_) {
+        int retval;
+        retval = mdb_cursor_open(txn_, *dbi_, &new_cursor);
+        CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
+        MDB_val key;
+        MDB_val val;
+        retval = mdb_cursor_get(cursor_, &key, &val, MDB_GET_CURRENT);
+        CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
+        retval = mdb_cursor_get(new_cursor, &key, &val, MDB_SET);
+        CHECK_EQ(MDB_SUCCESS, retval) << mdb_strerror(retval);
+      } else {
+        new_cursor = cursor_;
+      }
+
+      return shared_ptr<DatabaseState>(new LmdbState(new_cursor, txn_, dbi_));
+    }
 
     MDB_cursor* cursor_;
+    MDB_txn* txn_;
+    const MDB_dbi* dbi_;
     KV kv_pair_;
   };
 
index d2b37e6..51d50cc 100644 (file)
@@ -99,13 +99,13 @@ LeveldbDatabase::const_iterator LeveldbDatabase::begin() const {
   if (!iter->Valid()) {
     iter.reset();
   }
-  shared_ptr<DatabaseState> state(new LeveldbState(iter));
+  shared_ptr<DatabaseState> state(new LeveldbState(db_, iter));
   return const_iterator(this, state);
 }
 
 LeveldbDatabase::const_iterator LeveldbDatabase::end() const {
   shared_ptr<leveldb::Iterator> iter;
-  shared_ptr<DatabaseState> state(new LeveldbState(iter));
+  shared_ptr<DatabaseState> state(new LeveldbState(db_, iter));
   return const_iterator(this, state);
 }
 
index 952e95a..a546c8c 100644 (file)
@@ -134,12 +134,12 @@ LmdbDatabase::const_iterator LmdbDatabase::begin() const {
   retval = mdb_cursor_get(cursor, &key, &val, MDB_FIRST);
   CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
 
-  shared_ptr<DatabaseState> state(new LmdbState(cursor));
+  shared_ptr<DatabaseState> state(new LmdbState(cursor, txn_, &dbi_));
   return const_iterator(this, state);
 }
 
 LmdbDatabase::const_iterator LmdbDatabase::end() const {
-  shared_ptr<DatabaseState> state(new LmdbState(NULL));
+  shared_ptr<DatabaseState> state(new LmdbState(NULL, txn_, &dbi_));
   return const_iterator(this, state);
 }
 
index dce92d8..5d5f4ea 100644 (file)
@@ -34,6 +34,32 @@ class DatabaseTest : public MultiDeviceTest<TypeParam> {
     Database::buffer_t value(kValue, kValue + 5);
     return value;
   }
+
+  Database::buffer_t TestAltKey() {
+    const char* kKey = "foo";
+    Database::buffer_t key(kKey, kKey + 3);
+    return key;
+  }
+
+  Database::buffer_t TestAltValue() {
+    const char* kValue = "bar";
+    Database::buffer_t value(kValue, kValue + 3);
+    return value;
+  }
+
+  bool BufferEq(const Database::buffer_t& buf1,
+      const Database::buffer_t& buf2) {
+    if (buf1.size() != buf2.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < buf1.size(); ++i) {
+      if (buf1.at(i) != buf2.at(i)) {
+        return false;
+      }
+    }
+
+    return true;
+  }
 };
 
 TYPED_TEST_CASE(DatabaseTest, TestDtypesAndDevices);
@@ -114,6 +140,87 @@ TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) {
   EXPECT_EQ(kNumExamples, count);
 }
 
+TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLevelDB) {
+  string name = this->DBName();
+  shared_ptr<Database> database = DatabaseFactory("leveldb");
+  database->open(name, Database::New);
+
+  Database::buffer_t key1 = this->TestAltKey();
+  Database::buffer_t value1 = this->TestAltValue();
+
+  Database::buffer_t key2 = this->TestKey();
+  Database::buffer_t value2 = this->TestValue();
+
+  database->put(&key1, &value1);
+  database->put(&key2, &value2);
+  database->commit();
+
+  Database::const_iterator iter1 = database->begin();
+
+  EXPECT_FALSE(database->end() == iter1);
+
+  EXPECT_TRUE(this->BufferEq(iter1->key, key1));
+
+  Database::const_iterator iter2 = ++iter1;
+
+  EXPECT_FALSE(database->end() == iter1);
+  EXPECT_FALSE(database->end() == iter2);
+
+  EXPECT_TRUE(this->BufferEq(iter2->key, key2));
+
+  Database::const_iterator iter3 = ++iter2;
+
+  EXPECT_TRUE(database->end() == iter3);
+
+  iter1 = database->end();
+  iter2 = database->end();
+  iter3 = database->end();
+
+  database->close();
+}
+
+TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLevelDB) {
+  string name = this->DBName();
+  shared_ptr<Database> database = DatabaseFactory("leveldb");
+  database->open(name, Database::New);
+
+  Database::buffer_t key1 = this->TestAltKey();
+  Database::buffer_t value1 = this->TestAltValue();
+
+  Database::buffer_t key2 = this->TestKey();
+  Database::buffer_t value2 = this->TestValue();
+
+  database->put(&key1, &value1);
+  database->put(&key2, &value2);
+  database->commit();
+
+  Database::const_iterator iter1 = database->begin();
+
+  EXPECT_FALSE(database->end() == iter1);
+
+  EXPECT_TRUE(this->BufferEq(iter1->key, key1));
+
+  Database::const_iterator iter2 = iter1++;
+
+  EXPECT_FALSE(database->end() == iter1);
+  EXPECT_FALSE(database->end() == iter2);
+
+  EXPECT_TRUE(this->BufferEq(iter2->key, key1));
+  EXPECT_TRUE(this->BufferEq(iter1->key, key2));
+
+  Database::const_iterator iter3 = iter1++;
+
+  EXPECT_FALSE(database->end() == iter3);
+  EXPECT_TRUE(this->BufferEq(iter3->key, key2));
+  EXPECT_TRUE(database->end() == iter1);
+
+  iter1 = database->end();
+  iter2 = database->end();
+  iter3 = database->end();
+
+  database->close();
+}
+
 TYPED_TEST(DatabaseTest, TestNewPutLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
@@ -155,10 +262,7 @@ TYPED_TEST(DatabaseTest, TestNewGetLevelDBPasses) {
 
   database->get(&key, &new_val);
 
-  EXPECT_EQ(val.size(), new_val.size());
-  for (size_t i = 0; i < val.size(); ++i) {
-    EXPECT_EQ(val.at(i), new_val.at(i));
-  }
+  EXPECT_TRUE(this->BufferEq(val, new_val));
 
   database->close();
 }
@@ -220,10 +324,7 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLevelDBPasses) {
 
   database->get(&key, &new_val);
 
-  EXPECT_EQ(val.size(), new_val.size());
-  for (size_t i = 0; i < val.size(); ++i) {
-    EXPECT_EQ(val.at(i), new_val.at(i));
-  }
+  EXPECT_TRUE(this->BufferEq(val, new_val));
 
   database->close();
 }
@@ -288,10 +389,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) {
 
   database->get(&key, &new_val);
 
-  EXPECT_EQ(val.size(), new_val.size());
-  for (size_t i = 0; i < val.size(); ++i) {
-    EXPECT_EQ(val.at(i), new_val.at(i));
-  }
+  EXPECT_TRUE(this->BufferEq(val, new_val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) {
@@ -389,6 +487,87 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) {
   EXPECT_EQ(kNumExamples, count);
 }
 
+TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLMDB) {
+  string name = this->DBName();
+  shared_ptr<Database> database = DatabaseFactory("lmdb");
+  database->open(name, Database::New);
+
+  Database::buffer_t key1 = this->TestAltKey();
+  Database::buffer_t value1 = this->TestAltValue();
+
+  Database::buffer_t key2 = this->TestKey();
+  Database::buffer_t value2 = this->TestValue();
+
+  database->put(&key1, &value1);
+  database->put(&key2, &value2);
+  database->commit();
+
+  Database::const_iterator iter1 = database->begin();
+
+  EXPECT_FALSE(database->end() == iter1);
+
+  EXPECT_TRUE(this->BufferEq(iter1->key, key1));
+
+  Database::const_iterator iter2 = ++iter1;
+
+  EXPECT_FALSE(database->end() == iter1);
+  EXPECT_FALSE(database->end() == iter2);
+
+  EXPECT_TRUE(this->BufferEq(iter2->key, key2));
+
+  Database::const_iterator iter3 = ++iter2;
+
+  EXPECT_TRUE(database->end() == iter3);
+
+  iter1 = database->end();
+  iter2 = database->end();
+  iter3 = database->end();
+
+  database->close();
+}
+
+TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLMDB) {
+  string name = this->DBName();
+  shared_ptr<Database> database = DatabaseFactory("lmdb");
+  database->open(name, Database::New);
+
+  Database::buffer_t key1 = this->TestAltKey();
+  Database::buffer_t value1 = this->TestAltValue();
+
+  Database::buffer_t key2 = this->TestKey();
+  Database::buffer_t value2 = this->TestValue();
+
+  database->put(&key1, &value1);
+  database->put(&key2, &value2);
+  database->commit();
+
+  Database::const_iterator iter1 = database->begin();
+
+  EXPECT_FALSE(database->end() == iter1);
+
+  EXPECT_TRUE(this->BufferEq(iter1->key, key1));
+
+  Database::const_iterator iter2 = iter1++;
+
+  EXPECT_FALSE(database->end() == iter1);
+  EXPECT_FALSE(database->end() == iter2);
+
+  EXPECT_TRUE(this->BufferEq(iter2->key, key1));
+  EXPECT_TRUE(this->BufferEq(iter1->key, key2));
+
+  Database::const_iterator iter3 = iter1++;
+
+  EXPECT_FALSE(database->end() == iter3);
+  EXPECT_TRUE(this->BufferEq(iter3->key, key2));
+  EXPECT_TRUE(database->end() == iter1);
+
+  iter1 = database->end();
+  iter2 = database->end();
+  iter3 = database->end();
+
+  database->close();
+}
+
 TYPED_TEST(DatabaseTest, TestNewPutLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
@@ -430,10 +609,7 @@ TYPED_TEST(DatabaseTest, TestNewGetLMDBPasses) {
 
   database->get(&key, &new_val);
 
-  EXPECT_EQ(val.size(), new_val.size());
-  for (size_t i = 0; i < val.size(); ++i) {
-    EXPECT_EQ(val.at(i), new_val.at(i));
-  }
+  EXPECT_TRUE(this->BufferEq(val, new_val));
 
   database->close();
 }
@@ -494,10 +670,7 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLMDBPasses) {
 
   database->get(&key, &new_val);
 
-  EXPECT_EQ(val.size(), new_val.size());
-  for (size_t i = 0; i < val.size(); ++i) {
-    EXPECT_EQ(val.at(i), new_val.at(i));
-  }
+  EXPECT_TRUE(this->BufferEq(val, new_val));
 
   database->close();
 }
@@ -562,10 +735,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) {
 
   database->get(&key, &new_val);
 
-  EXPECT_EQ(val.size(), new_val.size());
-  for (size_t i = 0; i < val.size(); ++i) {
-    EXPECT_EQ(val.at(i), new_val.at(i));
-  }
+  EXPECT_TRUE(this->BufferEq(val, new_val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) {