fix test_net for refactor
authorJeff Donahue <jeff.donahue@gmail.com>
Fri, 21 Mar 2014 21:04:11 +0000 (14:04 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Fri, 28 Mar 2014 06:42:28 +0000 (23:42 -0700)
src/caffe/test/test_net.cpp

index 77b3516a980e5bfd4f7b85d770c58568c264e43c..4c7f0e7f7ac7e6bb35c2e6526da5e3e1dee319f4 100644 (file)
@@ -18,17 +18,14 @@ namespace caffe {
 template <typename Dtype>
 class NetTest : public ::testing::Test {
  protected:
-  NetTest() : filename(NULL) {
-  }
-
   virtual void SetUp() {  // Create the leveldb
-    filename = tmpnam(NULL);  // get temp name
-    LOG(INFO) << "Using temporary leveldb " << filename;
+    filename_ = tmpnam(NULL);  // get temp name
+    LOG(INFO) << "Using temporary leveldb " << filename_;
     leveldb::DB* db;
     leveldb::Options options;
     options.error_if_exists = true;
     options.create_if_missing = true;
-    leveldb::Status status = leveldb::DB::Open(options, filename, &db);
+    leveldb::Status status = leveldb::DB::Open(options, filename_, &db);
     CHECK(status.ok());
     for (int i = 0; i < 5; ++i) {
       Datum datum;
@@ -49,19 +46,19 @@ class NetTest : public ::testing::Test {
     const string& proto_prefix =
         "name: 'TestNetwork' "
         "layers: { "
-        "  layer { "
-        "    name: 'data' "
-        "    type: 'data' ";
+        "  name: 'data' "
+        "  type: DATA "
+        "  data_param { ";
     const string& proto_suffix =
-        "    batchsize: 1 "
+        "    batch_size: 1 "
         "  } "
         "  top: 'data' "
         "  top: 'label' "
         "} "
         "layers: { "
-        "  layer { "
-        "    name: 'innerproduct' "
-        "    type: 'innerproduct' "
+        "  name: 'innerproduct' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
         "    num_output: 1000 "
         "    weight_filler { "
         "      type: 'gaussian' "
@@ -71,31 +68,26 @@ class NetTest : public ::testing::Test {
         "      type: 'constant' "
         "      value: 0 "
         "    } "
-        "    blobs_lr: 1. "
-        "    blobs_lr: 2. "
-        "    weight_decay: 1. "
-        "    weight_decay: 0. "
         "  } "
+        "  blobs_lr: 1. "
+        "  blobs_lr: 2. "
+        "  weight_decay: 1. "
+        "  weight_decay: 0. "
         "  bottom: 'data' "
         "  top: 'innerproduct' "
         "} "
         "layers: { "
-        "  layer { "
-        "    name: 'loss' "
-        "    type: 'softmax_loss' "
-        "  } "
+        "  name: 'loss' "
+        "  type: SOFTMAX_LOSS "
         "  bottom: 'innerproduct' "
         "  bottom: 'label' "
         "} ";
-    proto = proto_prefix + "source: '" + string(this->filename) +
+    proto_ = proto_prefix + "source: '" + string(this->filename_) +
         "' " + proto_suffix;
   }
 
-  virtual ~NetTest() {
-  }
-
-  char* filename;
-  string proto;
+  char* filename_;
+  string proto_;
 };
 
 typedef ::testing::Types<float, double> Dtypes;
@@ -103,8 +95,7 @@ TYPED_TEST_CASE(NetTest, Dtypes);
 
 TYPED_TEST(NetTest, TestHasBlob) {
   NetParameter param;
-  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
-                                                      &param));
+  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto_, &param));
   Net<TypeParam> net(param);
   EXPECT_TRUE(net.has_blob("data"));
   EXPECT_TRUE(net.has_blob("label"));
@@ -114,8 +105,7 @@ TYPED_TEST(NetTest, TestHasBlob) {
 
 TYPED_TEST(NetTest, TestGetBlob) {
   NetParameter param;
-  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
-                                                      &param));
+  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto_, &param));
   Net<TypeParam> net(param);
   EXPECT_EQ(net.blob_by_name("data"), net.blobs()[0]);
   EXPECT_EQ(net.blob_by_name("label"), net.blobs()[1]);
@@ -125,8 +115,7 @@ TYPED_TEST(NetTest, TestGetBlob) {
 
 TYPED_TEST(NetTest, TestHasLayer) {
   NetParameter param;
-  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
-                                                      &param));
+  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto_, &param));
   Net<TypeParam> net(param);
   EXPECT_TRUE(net.has_layer("data"));
   EXPECT_TRUE(net.has_layer("innerproduct"));
@@ -136,8 +125,7 @@ TYPED_TEST(NetTest, TestHasLayer) {
 
 TYPED_TEST(NetTest, TestGetLayerByName) {
   NetParameter param;
-  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
-                                                      &param));
+  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto_, &param));
   Net<TypeParam> net(param);
   EXPECT_EQ(net.layer_by_name("data"), net.layers()[0]);
   EXPECT_EQ(net.layer_by_name("innerproduct"), net.layers()[1]);