void CopyTrainedLayersFrom(const string trained_filename);
// Writes the net to a proto.
void ToProto(NetParameter* param, bool write_diff = false);
- // HasBlob and GetBlob are inspired by
- // https://github.com/kencoken/caffe/commit/f36e71569455c9fbb4bf8a63c2d53224e32a4e7b
- // Access intermediary computation layers, testing with centre image only
- bool HasBlob(const string& blob_name);
- const shared_ptr<Blob<Dtype> > GetBlob(const string& blob_name);
- bool HasLayer(const string& layer_name);
- const shared_ptr<Layer<Dtype> > GetLayerByName(const string& layer_name);
// returns the network name.
inline const string& name() { return name_; }
inline int num_outputs() { return net_output_blobs_.size(); }
inline vector<Blob<Dtype>*>& input_blobs() { return net_input_blobs_; }
inline vector<Blob<Dtype>*>& output_blobs() { return net_output_blobs_; }
+ // has_blob and blob_by_name are inspired by
+ // https://github.com/kencoken/caffe/commit/f36e71569455c9fbb4bf8a63c2d53224e32a4e7b
+ // Access intermediary computation layers, testing with centre image only
+ bool has_blob(const string& blob_name);
+ const shared_ptr<Blob<Dtype> > blob_by_name(const string& blob_name);
+ bool has_layer(const string& layer_name);
+ const shared_ptr<Layer<Dtype> > layer_by_name(const string& layer_name);
protected:
// Function to get misc parameters, e.g. the learning rate multiplier and
}
}
-// HasBlob and GetBlob are inspired by
-// https://github.com/kencoken/caffe/commit/f36e71569455c9fbb4bf8a63c2d53224e32a4e7b
-// Access intermediary computation layers, testing with centre image only
template <typename Dtype>
-bool Net<Dtype>::HasBlob(const string& blob_name) {
+bool Net<Dtype>::has_blob(const string& blob_name) {
return blob_names_index_.find(blob_name) != blob_names_index_.end();
}
template <typename Dtype>
-const shared_ptr<Blob<Dtype> > Net<Dtype>::GetBlob(const string& blob_name) {
+const shared_ptr<Blob<Dtype> > Net<Dtype>::blob_by_name(const string& blob_name) {
shared_ptr<Blob<Dtype> > blob_ptr;
- if (HasBlob(blob_name)) {
+ if (has_blob(blob_name)) {
blob_ptr = blobs_[blob_names_index_[blob_name]];
} else {
blob_ptr.reset((Blob<Dtype>*)(NULL));
}
template <typename Dtype>
-bool Net<Dtype>::HasLayer(const string& layer_name) {
+bool Net<Dtype>::has_layer(const string& layer_name) {
return layer_names_index_.find(layer_name) != layer_names_index_.end();
}
template <typename Dtype>
-const shared_ptr<Layer<Dtype> > Net<Dtype>::GetLayerByName(const string& layer_name) {
+const shared_ptr<Layer<Dtype> > Net<Dtype>::layer_by_name(const string& layer_name) {
shared_ptr<Layer<Dtype> > layer_ptr;
- if (HasLayer(layer_name)) {
+ if (has_layer(layer_name)) {
layer_ptr = layers_[layer_names_index_[layer_name]];
} else {
layer_ptr.reset((Layer<Dtype>*)(NULL));
LOG(ERROR) << "Unknown layer name " << layer_name;
}
return layer_ptr;
-
}
INSTANTIATE_CLASS(Net);
// Copyright 2014 kloudkl@github
+#include <sstream>
#include <google/protobuf/text_format.h>
#include <leveldb/db.h>
for (int j = 0; j < 24; ++j) {
data->push_back((uint8_t)i);
}
- stringstream ss;
+ std::stringstream ss;
ss << i;
db->Put(leveldb::WriteOptions(), ss.str(), datum.SerializeAsString());
}
CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
¶m));
Net<TypeParam> net(param);
- EXPECT_TRUE(net.HasBlob("data"));
- EXPECT_TRUE(net.HasBlob("label"));
- EXPECT_TRUE(net.HasBlob("innerproduct"));
- EXPECT_FALSE(net.HasBlob("loss"));
+ EXPECT_TRUE(net.has_blob("data"));
+ EXPECT_TRUE(net.has_blob("label"));
+ EXPECT_TRUE(net.has_blob("innerproduct"));
+ EXPECT_FALSE(net.has_blob("loss"));
}
TYPED_TEST(NetTest, TestGetBlob) {
CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
¶m));
Net<TypeParam> net(param);
- EXPECT_EQ(net.GetBlob("data"), net.blobs()[0]);
- EXPECT_EQ(net.GetBlob("label"), net.blobs()[1]);
- EXPECT_EQ(net.GetBlob("innerproduct"), net.blobs()[2]);
- EXPECT_FALSE(net.GetBlob("loss"));
+ EXPECT_EQ(net.blob_by_name("data"), net.blobs()[0]);
+ EXPECT_EQ(net.blob_by_name("label"), net.blobs()[1]);
+ EXPECT_EQ(net.blob_by_name("innerproduct"), net.blobs()[2]);
+ EXPECT_FALSE(net.blob_by_name("loss"));
}
TYPED_TEST(NetTest, TestHasLayer) {
CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
¶m));
Net<TypeParam> net(param);
- EXPECT_TRUE(net.HasLayer("data"));
- EXPECT_TRUE(net.HasLayer("innerproduct"));
- EXPECT_TRUE(net.HasLayer("loss"));
- EXPECT_FALSE(net.HasLayer("label"));
+ EXPECT_TRUE(net.has_layer("data"));
+ EXPECT_TRUE(net.has_layer("innerproduct"));
+ EXPECT_TRUE(net.has_layer("loss"));
+ EXPECT_FALSE(net.has_layer("label"));
}
TYPED_TEST(NetTest, TestGetLayerByName) {
CHECK(google::protobuf::TextFormat::ParseFromString(this->proto,
¶m));
Net<TypeParam> net(param);
- EXPECT_EQ(net.GetLayerByName("data"), net.layers()[0]);
- EXPECT_EQ(net.GetLayerByName("innerproduct"), net.layers()[1]);
- EXPECT_EQ(net.GetLayerByName("loss"), net.layers()[2]);
- EXPECT_FALSE(net.GetLayerByName("label"));
+ EXPECT_EQ(net.layer_by_name("data"), net.layers()[0]);
+ EXPECT_EQ(net.layer_by_name("innerproduct"), net.layers()[1]);
+ EXPECT_EQ(net.layer_by_name("loss"), net.layers()[2]);
+ EXPECT_FALSE(net.layer_by_name("label"));
}
new Net<Dtype>(real_valued_feature_net_param));
string feature_blob_name(argv[++arg_pos]);
- CHECK(real_valued_feature_net->HasBlob(feature_blob_name))
+ CHECK(real_valued_feature_net->has_blob(feature_blob_name))
<< "Unknown feature blob name " << feature_blob_name << " in the network "
<< real_valued_feature_prototxt;
for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
real_valued_feature_net->Forward(input_vec);
const shared_ptr<Blob<Dtype> > feature_blob = real_valued_feature_net
- ->GetBlob(feature_blob_name);
+ ->blob_by_name(feature_blob_name);
feature_blob_vector.push_back(feature_blob);
}
shared_ptr<Blob<Dtype> > feature_binary_codes(new Blob<Dtype>());
feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);
string extract_feature_blob_name(argv[++arg_pos]);
- CHECK(feature_extraction_net->HasBlob(extract_feature_blob_name))
+ CHECK(feature_extraction_net->has_blob(extract_feature_blob_name))
<< "Unknown feature blob name " << extract_feature_blob_name
<< " in the network " << feature_extraction_proto;
for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec);
const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
- ->GetBlob(extract_feature_blob_name);
+ ->blob_by_name(extract_feature_blob_name);
int num_features = feature_blob->num();
int dim_features = feature_blob->count() / num_features;
Dtype* feature_blob_data;