Add and test Net::HasLayer and GetLayerByName
author: Kai Li <kaili_kloud@163.com>
Sun, 23 Feb 2014 09:08:20 +0000 (17:08 +0800)
committer: Kai Li <kaili_kloud@163.com>
Wed, 19 Mar 2014 15:04:41 +0000 (23:04 +0800)
include/caffe/net.hpp
src/caffe/net.cpp
src/caffe/test/test_net.cpp [new file with mode: 0644]

index c6f968b..ad7c052 100644 (file)
@@ -62,6 +62,9 @@ class Net {
   // Access intermediary computation layers, testing with centre image only
   bool HasBlob(const string& blob_name);
   const shared_ptr<Blob<Dtype> > GetBlob(const string& blob_name);
+  // Look up layers by name; GetLayerByName returns an empty pointer on miss.
+  bool HasLayer(const string& layer_name);
+  const shared_ptr<Layer<Dtype> > GetLayerByName(const string& layer_name);
 
   // returns the network name.
   inline const string& name() { return name_; }
@@ -96,6 +99,8 @@
   // Individual layers in the net
   vector<shared_ptr<Layer<Dtype> > > layers_;
   vector<string> layer_names_;
+  // Maps a layer name to its index in layers_ / layer_names_ (built in Init).
+  map<string, int> layer_names_index_;
   vector<bool> layer_need_backward_;
   // blobs stores the blobs that store intermediate results between the
   // layers.
index 5633088..75d863e 100644 (file)
@@ -165,6 +165,10 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   for (size_t i = 0; i < blob_names_.size(); ++i) {
     blob_names_index_[blob_names_[i]] = i;
   }
+  // Build the layer name -> index map, mirroring blob_names_index_ above.
+  for (size_t i = 0; i < layer_names_.size(); ++i) {
+    layer_names_index_[layer_names_[i]] = i;
+  }
   GetLearningRateAndWeightDecay();
   LOG(INFO) << "Network initialization done.";
   LOG(INFO) << "Memory required for Data " << memory_used*sizeof(Dtype);
@@ -344,12 +348,34 @@ const shared_ptr<Blob<Dtype> > Net<Dtype>::GetBlob(const string& blob_name) {
   if (HasBlob(blob_name)) {
     blob_ptr = blobs_[blob_names_index_[blob_name]];
   } else {
-    blob_ptr.reset(new Blob<Dtype>());
+    // Leave the pointer empty so callers can test for a missing blob.
+    blob_ptr.reset();
     LOG(ERROR) << "Unknown blob name " << blob_name;
   }
   return blob_ptr;
 }
 
+// Returns true iff the net has a layer with the given name.
+template <typename Dtype>
+bool Net<Dtype>::HasLayer(const string& layer_name) {
+  return layer_names_index_.find(layer_name) != layer_names_index_.end();
+}
+
+// Returns the named layer, or an empty pointer (with an error logged)
+// when no such layer exists.
+template <typename Dtype>
+const shared_ptr<Layer<Dtype> > Net<Dtype>::GetLayerByName(
+    const string& layer_name) {
+  shared_ptr<Layer<Dtype> > layer_ptr;
+  if (HasLayer(layer_name)) {
+    layer_ptr = layers_[layer_names_index_[layer_name]];
+  } else {
+    layer_ptr.reset();
+    LOG(ERROR) << "Unknown layer name " << layer_name;
+  }
+  return layer_ptr;
+}
+
 INSTANTIATE_CLASS(Net);
 
 }  // namespace caffe
diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp
new file mode 100644 (file)
index 0000000..2adfb9c
--- /dev/null
@@ -0,0 +1,151 @@
+// Copyright 2014 kloudkl@github
+
+#include <google/protobuf/text_format.h>
+#include <leveldb/db.h>
+
+#include "gtest/gtest.h"
+#include "caffe/common.hpp"
+#include "caffe/net.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+
+// Fixture that builds a tiny data->innerproduct->softmax_loss net definition
+// backed by a temporary 5-datum leveldb.
+template <typename Dtype>
+class NetTest : public ::testing::Test {
+ protected:
+  NetTest() : filename(NULL) {
+  }
+  virtual void SetUp() {
+    // Create the leveldb
+    // NOTE(review): tmpnam() is race-prone; fine for a test, mkdtemp is safer.
+    filename = tmpnam(NULL); // get temp name
+    LOG(INFO) << "Using temporary leveldb " << filename;
+    leveldb::DB* db;
+    leveldb::Options options;
+    options.error_if_exists = true;
+    options.create_if_missing = true;
+    leveldb::Status status = leveldb::DB::Open(options, filename, &db);
+    CHECK(status.ok());
+    for (int i = 0; i < 5; ++i) {
+      Datum datum;
+      datum.set_label(i);
+      datum.set_channels(2);
+      datum.set_height(3);
+      datum.set_width(4);
+      std::string* data = datum.mutable_data();
+      for (int j = 0; j < 24; ++j) {
+        data->push_back(static_cast<char>(i));
+      }
+      stringstream ss;
+      ss << i;
+      db->Put(leveldb::WriteOptions(), ss.str(), datum.SerializeAsString());
+    }
+    delete db;
+
+    const string& proto_prefix =
+        "name: 'TestNetwork' "
+        "layers: { "
+        "  layer { "
+        "    name: 'data' "
+        "    type: 'data' "
+        ;
+    const string& proto_suffix =
+        "    batchsize: 1 "
+        "  } "
+        "  top: 'data' "
+        "  top: 'label' "
+        "} "
+        "layers: { "
+        "  layer { "
+        "    name: 'innerproduct' "
+        "    type: 'innerproduct' "
+        "    num_output: 1000 "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "    bias_filler { "
+        "      type: 'constant' "
+        "      value: 0 "
+        "    } "
+        "    blobs_lr: 1. "
+        "    blobs_lr: 2. "
+        "    weight_decay: 1. "
+        "    weight_decay: 0. "
+        "  } "
+        "  bottom: 'data' "
+        "  top: 'innerproduct' "
+        "} "
+        "layers: { "
+        "  layer { "
+        "    name: 'loss' "
+        "    type: 'softmax_loss' "
+        "  } "
+        "  bottom: 'innerproduct' "
+        "  bottom: 'label' "
+        "} "
+        ;
+    proto = proto_prefix + "source: '" + string(this->filename) +
+        "' " + proto_suffix;
+  }
+
+  virtual ~NetTest() {
+    // TODO: remove the temporary leveldb created in SetUp().
+  }
+
+  char* filename;
+  string proto;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(NetTest, Dtypes);
+
+TYPED_TEST(NetTest, TestHasBlob) {
+  NetParameter param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
+                                                      &param));
+  Net<TypeParam> net(param);
+  EXPECT_TRUE(net.HasBlob("data"));
+  EXPECT_TRUE(net.HasBlob("label"));
+  EXPECT_TRUE(net.HasBlob("innerproduct"));
+  EXPECT_FALSE(net.HasBlob("loss"));
+}
+
+TYPED_TEST(NetTest, TestGetBlob) {
+  NetParameter param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
+                                                      &param));
+  Net<TypeParam> net(param);
+  EXPECT_EQ(net.GetBlob("data"), net.blobs()[0]);
+  EXPECT_EQ(net.GetBlob("label"), net.blobs()[1]);
+  EXPECT_EQ(net.GetBlob("innerproduct"), net.blobs()[2]);
+  EXPECT_FALSE(net.GetBlob("loss"));
+}
+
+TYPED_TEST(NetTest, TestHasLayer) {
+  NetParameter param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
+                                                      &param));
+  Net<TypeParam> net(param);
+  EXPECT_TRUE(net.HasLayer("data"));
+  EXPECT_TRUE(net.HasLayer("innerproduct"));
+  EXPECT_TRUE(net.HasLayer("loss"));
+  EXPECT_FALSE(net.HasLayer("label"));
+}
+
+TYPED_TEST(NetTest, TestGetLayerByName) {
+  NetParameter param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
+                                                      &param));
+  Net<TypeParam> net(param);
+  EXPECT_EQ(net.GetLayerByName("data"), net.layers()[0]);
+  EXPECT_EQ(net.GetLayerByName("innerproduct"), net.layers()[1]);
+  EXPECT_EQ(net.GetLayerByName("loss"), net.layers()[2]);
+  EXPECT_FALSE(net.GetLayerByName("label"));
+}
+
+
+}  // namespace caffe