using namespace caffe;
-
template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv);
const int num_required_args = 6;
if (argc < num_required_args) {
LOG(ERROR)<<
- "This program takes in a trained network and an input data layer, and then"
- " extract features of the input data produced by the net.\n"
- "Usage: demo_extract_features pretrained_net_param"
- " feature_extraction_proto_file extract_feature_blob_name"
- " save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]";
+ "This program takes in a trained network and an input data layer, and then"
+ " extract features of the input data produced by the net.\n"
+ "Usage: demo_extract_features pretrained_net_param"
+ " feature_extraction_proto_file extract_feature_blob_name"
+ " save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]";
return 1;
}
int arg_pos = num_required_args;
&pretrained_net_param);
// Expected prototxt contains at least one data layer such as
- // the layer data_layer_name and one feature blob such as the
- // fc7 top blob to extract features.
- /*
- layers {
- layer {
- name: "data_layer_name"
- type: "data"
- source: "/path/to/your/images/to/extract/feature/images_leveldb"
- meanfile: "/path/to/your/image_mean.binaryproto"
- batchsize: 128
- cropsize: 227
- mirror: false
- }
- top: "data_blob_name"
- top: "label_blob_name"
- }
- layers {
- layer {
- name: "drop7"
- type: "dropout"
- dropout_ratio: 0.5
- }
- bottom: "fc7"
- top: "fc7"
- }
- */
- NetParameter feature_extraction_net_param;;
+ // the layer data_layer_name and one feature blob such as the
+ // fc7 top blob to extract features.
+ /*
+ layers {
+ layer {
+ name: "data_layer_name"
+ type: "data"
+ source: "/path/to/your/images/to/extract/feature/images_leveldb"
+ meanfile: "/path/to/your/image_mean.binaryproto"
+ batchsize: 128
+ cropsize: 227
+ mirror: false
+ }
+ top: "data_blob_name"
+ top: "label_blob_name"
+ }
+ layers {
+ layer {
+ name: "drop7"
+ type: "dropout"
+ dropout_ratio: 0.5
+ }
+ bottom: "fc7"
+ top: "fc7"
+ }
+ */
+ NetParameter feature_extraction_net_param;
+ ;
string feature_extraction_proto(argv[++arg_pos]);
ReadProtoFromTextFile(feature_extraction_proto,
&feature_extraction_net_param);
feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);
string extract_feature_blob_name(argv[++arg_pos]);
- if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) {
- LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name <<
- " in the network " << feature_extraction_proto;
- return 1;
- }
+ CHECK(feature_extraction_net->HasBlob(extract_feature_blob_name))
+ << "Unknown feature blob name " << extract_feature_blob_name
+ << " in the network " << feature_extraction_proto;
string save_feature_leveldb_name(argv[++arg_pos]);
leveldb::DB* db;
options.error_if_exists = true;
options.create_if_missing = true;
options.write_buffer_size = 268435456;
- LOG(INFO) << "Opening leveldb " << save_feature_leveldb_name;
- leveldb::Status status = leveldb::DB::Open(
- options, save_feature_leveldb_name.c_str(), &db);
+ LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name;
+ leveldb::Status status = leveldb::DB::Open(options,
+ save_feature_leveldb_name.c_str(),
+ &db);
CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
int num_mini_batches = atoi(argv[++arg_pos]);
const int max_key_str_length = 100;
char key_str[max_key_str_length];
int num_bytes_of_binary_code = sizeof(Dtype);
- vector<Blob<float>* > input_vec;
+ vector<Blob<float>*> input_vec;
int image_index = 0;
for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec);
- const shared_ptr<Blob<Dtype> > feature_blob =
- feature_extraction_net->GetBlob(extract_feature_blob_name);
+ const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
+ ->GetBlob(extract_feature_blob_name);
int num_features = feature_blob->num();
int dim_features = feature_blob->count() / num_features;
for (int n = 0; n < num_features; ++n) {
- datum.set_height(dim_features);
- datum.set_width(1);
- datum.set_channels(1);
- datum.clear_data();
- datum.clear_float_data();
- string* datum_string = datum.mutable_data();
- const Dtype* feature_blob_data = feature_blob->cpu_data();
- for (int d = 0; d < dim_features; ++d) {
- const char* data_byte = reinterpret_cast<const char*>(feature_blob_data + d);
- for(int i = 0; i < num_bytes_of_binary_code; ++i) {
- datum_string->push_back(data_byte[i]);
- }
- }
- string value;
- datum.SerializeToString(&value);
- snprintf(key_str, max_key_str_length, "%d", image_index);
- batch->Put(string(key_str), value);
- if (++image_index % 1000 == 0) {
- db->Write(leveldb::WriteOptions(), batch);
- LOG(ERROR) << "Extracted features of " << image_index << " query images.";
- delete batch;
- batch = new leveldb::WriteBatch();
- }
+ datum.set_height(dim_features);
+ datum.set_width(1);
+ datum.set_channels(1);
+ datum.clear_data();
+ datum.clear_float_data();
+ string* datum_string = datum.mutable_data();
+ const Dtype* feature_blob_data = feature_blob->cpu_data();
+ for (int d = 0; d < dim_features; ++d) {
+ const char* data_byte = reinterpret_cast<const char*>(feature_blob_data
+ + d);
+ for (int i = 0; i < num_bytes_of_binary_code; ++i) {
+ datum_string->push_back(data_byte[i]);
+ }
+ }
+ string value;
+ datum.SerializeToString(&value);
+ snprintf(key_str, max_key_str_length, "%d", image_index);
+ batch->Put(string(key_str), value);
+ if (++image_index % 1000 == 0) {
+ db->Write(leveldb::WriteOptions(), batch);
+ LOG(ERROR)<< "Extracted features of " << image_index << " query images.";
+ delete batch;
+ batch = new leveldb::WriteBatch();
+ }
}
- } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
+ } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch
if (image_index % 1000 != 0) {
db->Write(leveldb::WriteOptions(), batch);
- LOG(ERROR) << "Extracted features of " << image_index << " query images.";
+ LOG(ERROR)<< "Extracted features of " << image_index << " query images.";
delete batch;
batch = new leveldb::WriteBatch();
}
delete batch;
delete db;
- LOG(ERROR)<< "Successfully ended!";
+ LOG(ERROR)<< "Successfully extracted the features!";
return 0;
}