Add and test Net::HasBlob and GetBlob to simplify feature extraction
authorKai Li <kaili_kloud@163.com>
Sun, 23 Feb 2014 08:27:17 +0000 (16:27 +0800)
committerKai Li <kaili_kloud@163.com>
Wed, 19 Mar 2014 15:04:41 +0000 (23:04 +0800)
Inspired by @kencoken's commit f36e715
https://github.com/kencoken/caffe/commit/f36e71569455c9fbb4bf8a63c2d53224e32a4e7b

Related issues:

include/caffe/net.hpp
src/caffe/net.cpp

index b5a57b3..c6f968b 100644 (file)
@@ -57,6 +57,11 @@ class Net {
   void CopyTrainedLayersFrom(const string trained_filename);
   // Writes the net to a proto.
   void ToProto(NetParameter* param, bool write_diff = false);
+  // HasBlob and GetBlob are inspired by
+  // https://github.com/kencoken/caffe/commit/f36e71569455c9fbb4bf8a63c2d53224e32a4e7b
+  // Access intermediary computation layers, testing with centre image only
+  bool HasBlob(const string& blob_name);
+  const shared_ptr<Blob<Dtype> > GetBlob(const string& blob_name);
 
   // returns the network name.
   inline const string& name() { return name_; }
@@ -96,6 +101,7 @@ class Net {
   // layers.
   vector<shared_ptr<Blob<Dtype> > > blobs_;
   vector<string> blob_names_;
+  map<string, int> blob_names_index_;
   vector<bool> blob_need_backward_;
   // bottom_vecs stores the vectors containing the input for each layer.
   // They don't actually host the blobs (blobs_ does), so we simply store
index 1837b07..5633088 100644 (file)
@@ -162,6 +162,9 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     LOG(INFO) << "This network produces output " << *it;
     net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
   }
+  for (size_t i = 0; i < blob_names_.size(); ++i) {
+    blob_names_index_[blob_names_[i]] = i;
+  }
   GetLearningRateAndWeightDecay();
   LOG(INFO) << "Network initialization done.";
   LOG(INFO) << "Memory required for Data " << memory_used*sizeof(Dtype);
@@ -327,6 +330,26 @@ void Net<Dtype>::Update() {
   }
 }
 
+// HasBlob and GetBlob are inspired by
+// https://github.com/kencoken/caffe/commit/f36e71569455c9fbb4bf8a63c2d53224e32a4e7b
+// Access intermediary computation layers, testing with centre image only
+template <typename Dtype>
+bool Net<Dtype>::HasBlob(const string& blob_name) {
+  return blob_names_index_.find(blob_name) != blob_names_index_.end();
+}
+
+template <typename Dtype>
+const shared_ptr<Blob<Dtype> > Net<Dtype>::GetBlob(const string& blob_name) {
+  shared_ptr<Blob<Dtype> > blob_ptr;
+  if (HasBlob(blob_name)) {
+    blob_ptr = blobs_[blob_names_index_[blob_name]];
+  } else {
+    blob_ptr.reset(new Blob<Dtype>());
+    LOG(ERROR) << "Unknown blob name " << blob_name;
+  }
+  return blob_ptr;
+}
+
 INSTANTIATE_CLASS(Net);
 
 }  // namespace caffe