From cd953c8245425fda3b94261f9c2e166a382c2430 Mon Sep 17 00:00:00 2001 From: Kai Li Date: Sun, 23 Feb 2014 16:27:17 +0800 Subject: [PATCH] Add and test Net::HasBlob and GetBlob to simplify feature extraction Inspired by @kencoken's commit f36e715 https://github.com/kencoken/caffe/commit/f36e71569455c9fbb4bf8a63c2d53224e32a4e7b Related issues: --- include/caffe/net.hpp | 6 ++++++ src/caffe/net.cpp | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp index b5a57b3..c6f968b 100644 --- a/include/caffe/net.hpp +++ b/include/caffe/net.hpp @@ -57,6 +57,11 @@ class Net { void CopyTrainedLayersFrom(const string trained_filename); // Writes the net to a proto. void ToProto(NetParameter* param, bool write_diff = false); + // HasBlob and GetBlob are inspired by + // https://github.com/kencoken/caffe/commit/f36e71569455c9fbb4bf8a63c2d53224e32a4e7b + // Access intermediary computation layers, testing with centre image only + bool HasBlob(const string& blob_name); + const shared_ptr > GetBlob(const string& blob_name); // returns the network name. inline const string& name() { return name_; } @@ -96,6 +101,7 @@ class Net { // layers. vector > > blobs_; vector blob_names_; + map blob_names_index_; vector blob_need_backward_; // bottom_vecs stores the vectors containing the input for each layer. // They don't actually host the blobs (blobs_ does), so we simply store diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 1837b07..5633088 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -162,6 +162,9 @@ void Net::Init(const NetParameter& in_param) { LOG(INFO) << "This network produces output " << *it; net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get()); } + for (size_t i = 0; i < blob_names_.size(); ++i) { + blob_names_index_[blob_names_[i]] = i; + } GetLearningRateAndWeightDecay(); LOG(INFO) << "Network initialization done."; LOG(INFO) << "Memory required for Data " << memory_used*sizeof(Dtype); @@ -327,6 +330,26 @@ void Net::Update() { } } +// HasBlob and GetBlob are inspired by +// https://github.com/kencoken/caffe/commit/f36e71569455c9fbb4bf8a63c2d53224e32a4e7b +// Access intermediary computation layers, testing with centre image only +template +bool Net::HasBlob(const string& blob_name) { + return blob_names_index_.find(blob_name) != blob_names_index_.end(); +} + +template +const shared_ptr > Net::GetBlob(const string& blob_name) { + shared_ptr > blob_ptr; + if (HasBlob(blob_name)) { + blob_ptr = blobs_[blob_names_index_[blob_name]]; + } else { + blob_ptr.reset(new Blob()); + LOG(ERROR) << "Unknown blob name " << blob_name; + } + return blob_ptr; +} + INSTANTIATE_CLASS(Net); } // namespace caffe -- 2.7.4