make test_net use DUMMY_DATA instead of leveldb
authorJeff Donahue <jeff.donahue@gmail.com>
Thu, 12 Jun 2014 21:58:07 +0000 (14:58 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Thu, 12 Jun 2014 22:36:03 +0000 (15:36 -0700)
src/caffe/test/test_net.cpp

index 4c7f0e7..1e43a03 100644 (file)
@@ -1,8 +1,6 @@
 // Copyright 2014 BVLC and contributors.
 
 #include <google/protobuf/text_format.h>
-#include <leveldb/db.h>
-#include <sstream>
 #include <string>
 
 #include "gtest/gtest.h"
 
 namespace caffe {
 
-
 template <typename Dtype>
 class NetTest : public ::testing::Test {
  protected:
-  virtual void SetUp() {  // Create the leveldb
-    filename_ = tmpnam(NULL);  // get temp name
-    LOG(INFO) << "Using temporary leveldb " << filename_;
-    leveldb::DB* db;
-    leveldb::Options options;
-    options.error_if_exists = true;
-    options.create_if_missing = true;
-    leveldb::Status status = leveldb::DB::Open(options, filename_, &db);
-    CHECK(status.ok());
-    for (int i = 0; i < 5; ++i) {
-      Datum datum;
-      datum.set_label(i);
-      datum.set_channels(2);
-      datum.set_height(3);
-      datum.set_width(4);
-      std::string* data = datum.mutable_data();
-      for (int j = 0; j < 24; ++j) {
-        data->push_back((uint8_t)i);
-      }
-      std::stringstream ss;
-      ss << i;
-      db->Put(leveldb::WriteOptions(), ss.str(), datum.SerializeAsString());
-    }
-    delete db;
-
-    const string& proto_prefix =
+  virtual void SetUp() {
+    const string& proto =
         "name: 'TestNetwork' "
         "layers: { "
         "  name: 'data' "
-        "  type: DATA "
-        "  data_param { ";
-    const string& proto_suffix =
-        "    batch_size: 1 "
+        "  type: DUMMY_DATA "
+        "  dummy_data_param { "
+        "    num: 5 "
+        "    channels: 2 "
+        "    height: 3 "
+        "    width: 4 "
+        "    num: 5 "
+        "    channels: 1 "
+        "    height: 1 "
+        "    width: 1 "
+        "    data_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
         "  } "
         "  top: 'data' "
         "  top: 'label' "
@@ -82,55 +65,43 @@ class NetTest : public ::testing::Test {
         "  bottom: 'innerproduct' "
         "  bottom: 'label' "
         "} ";
-    proto_ = proto_prefix + "source: '" + string(this->filename_) +
-        "' " + proto_suffix;
+    NetParameter param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
+    net_.reset(new Net<Dtype>(param));
   }
 
-  char* filename_;
-  string proto_;
+  shared_ptr<Net<Dtype> > net_;
 };
 
 typedef ::testing::Types<float, double> Dtypes;
 TYPED_TEST_CASE(NetTest, Dtypes);
 
 TYPED_TEST(NetTest, TestHasBlob) {
-  NetParameter param;
-  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto_, &param));
-  Net<TypeParam> net(param);
-  EXPECT_TRUE(net.has_blob("data"));
-  EXPECT_TRUE(net.has_blob("label"));
-  EXPECT_TRUE(net.has_blob("innerproduct"));
-  EXPECT_FALSE(net.has_blob("loss"));
+  EXPECT_TRUE(this->net_->has_blob("data"));
+  EXPECT_TRUE(this->net_->has_blob("label"));
+  EXPECT_TRUE(this->net_->has_blob("innerproduct"));
+  EXPECT_FALSE(this->net_->has_blob("loss"));
 }
 
 TYPED_TEST(NetTest, TestGetBlob) {
-  NetParameter param;
-  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto_, &param));
-  Net<TypeParam> net(param);
-  EXPECT_EQ(net.blob_by_name("data"), net.blobs()[0]);
-  EXPECT_EQ(net.blob_by_name("label"), net.blobs()[1]);
-  EXPECT_EQ(net.blob_by_name("innerproduct"), net.blobs()[2]);
-  EXPECT_FALSE(net.blob_by_name("loss"));
+  EXPECT_EQ(this->net_->blob_by_name("data"), this->net_->blobs()[0]);
+  EXPECT_EQ(this->net_->blob_by_name("label"), this->net_->blobs()[1]);
+  EXPECT_EQ(this->net_->blob_by_name("innerproduct"), this->net_->blobs()[2]);
+  EXPECT_FALSE(this->net_->blob_by_name("loss"));
 }
 
 TYPED_TEST(NetTest, TestHasLayer) {
-  NetParameter param;
-  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto_, &param));
-  Net<TypeParam> net(param);
-  EXPECT_TRUE(net.has_layer("data"));
-  EXPECT_TRUE(net.has_layer("innerproduct"));
-  EXPECT_TRUE(net.has_layer("loss"));
-  EXPECT_FALSE(net.has_layer("label"));
+  EXPECT_TRUE(this->net_->has_layer("data"));
+  EXPECT_TRUE(this->net_->has_layer("innerproduct"));
+  EXPECT_TRUE(this->net_->has_layer("loss"));
+  EXPECT_FALSE(this->net_->has_layer("label"));
 }
 
 TYPED_TEST(NetTest, TestGetLayerByName) {
-  NetParameter param;
-  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto_, &param));
-  Net<TypeParam> net(param);
-  EXPECT_EQ(net.layer_by_name("data"), net.layers()[0]);
-  EXPECT_EQ(net.layer_by_name("innerproduct"), net.layers()[1]);
-  EXPECT_EQ(net.layer_by_name("loss"), net.layers()[2]);
-  EXPECT_FALSE(net.layer_by_name("label"));
+  EXPECT_EQ(this->net_->layer_by_name("data"), this->net_->layers()[0]);
+  EXPECT_EQ(this->net_->layer_by_name("innerproduct"), this->net_->layers()[1]);
+  EXPECT_EQ(this->net_->layer_by_name("loss"), this->net_->layers()[2]);
+  EXPECT_FALSE(this->net_->layer_by_name("label"));
 }
 
 }  // namespace caffe