1 // Copyright 2014 kloudkl@github
3 #include <stdio.h> // for snprintf
4 #include <cuda_runtime.h>
5 #include <google/protobuf/text_format.h>
6 #include <leveldb/db.h>
7 #include <leveldb/write_batch.h>
11 #include "caffe/blob.hpp"
12 #include "caffe/common.hpp"
13 #include "caffe/net.hpp"
14 #include "caffe/vision_layers.hpp"
15 #include "caffe/proto/caffe.pb.h"
16 #include "caffe/util/io.hpp"
18 using namespace caffe; // NOLINT(build/namespaces)
// Forward declaration: main() below calls this before its definition and
// instantiates it with Dtype = float (the double instantiation is commented out).
20 template<typename Dtype>
21 int feature_extraction_pipeline(int argc, char** argv);
23 int main(int argc, char** argv) {
24 return feature_extraction_pipeline<float>(argc, argv);
25 // return feature_extraction_pipeline<double>(argc, argv);
28 template<typename Dtype>
29 int feature_extraction_pipeline(int argc, char** argv) {
30 const int num_required_args = 6;
31 if (argc < num_required_args) {
33 "This program takes in a trained network and an input data layer, and then"
34 " extract features of the input data produced by the net.\n"
35 "Usage: demo_extract_features pretrained_net_param"
36 " feature_extraction_proto_file extract_feature_blob_name"
37 " save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]";
40 int arg_pos = num_required_args;
42 arg_pos = num_required_args;
43 if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
44 LOG(ERROR)<< "Using GPU";
46 if (argc > arg_pos + 1) {
47 device_id = atoi(argv[arg_pos + 1]);
48 CHECK_GE(device_id, 0);
50 LOG(ERROR) << "Using Device_id=" << device_id;
51 Caffe::SetDevice(device_id);
52 Caffe::set_mode(Caffe::GPU);
54 LOG(ERROR) << "Using CPU";
55 Caffe::set_mode(Caffe::CPU);
57 Caffe::set_phase(Caffe::TEST);
59 NetParameter pretrained_net_param;
61 arg_pos = 0; // the name of the executable
62 string pretrained_binary_proto(argv[++arg_pos]);
63 ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(),
64 &pretrained_net_param);
66 // Expected prototxt contains at least one data layer such as
67 // the layer data_layer_name and one feature blob such as the
68 // fc7 top blob to extract features.
72 name: "data_layer_name"
74 source: "/path/to/your/images/to/extract/feature/images_leveldb"
75 meanfile: "/path/to/your/image_mean.binaryproto"
81 top: "label_blob_name"
93 NetParameter feature_extraction_net_param;
94 string feature_extraction_proto(argv[++arg_pos]);
95 ReadProtoFromTextFile(feature_extraction_proto,
96 &feature_extraction_net_param);
97 shared_ptr<Net<Dtype> > feature_extraction_net(
98 new Net<Dtype>(feature_extraction_net_param));
99 feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);
101 string extract_feature_blob_name(argv[++arg_pos]);
102 CHECK(feature_extraction_net->has_blob(extract_feature_blob_name))
103 << "Unknown feature blob name " << extract_feature_blob_name
104 << " in the network " << feature_extraction_proto;
106 string save_feature_leveldb_name(argv[++arg_pos]);
108 leveldb::Options options;
109 options.error_if_exists = true;
110 options.create_if_missing = true;
111 options.write_buffer_size = 268435456;
112 LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name;
113 leveldb::Status status = leveldb::DB::Open(options,
114 save_feature_leveldb_name.c_str(),
116 CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
118 int num_mini_batches = atoi(argv[++arg_pos]);
120 LOG(ERROR)<< "Extacting Features";
123 leveldb::WriteBatch* batch = new leveldb::WriteBatch();
124 const int kMaxKeyStrLength = 100;
125 char key_str[kMaxKeyStrLength];
126 int num_bytes_of_binary_code = sizeof(Dtype);
127 vector<Blob<float>*> input_vec;
129 for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
130 feature_extraction_net->Forward(input_vec);
131 const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
132 ->blob_by_name(extract_feature_blob_name);
133 int num_features = feature_blob->num();
134 int dim_features = feature_blob->count() / num_features;
135 Dtype* feature_blob_data;
136 for (int n = 0; n < num_features; ++n) {
137 datum.set_height(dim_features);
139 datum.set_channels(1);
141 datum.clear_float_data();
142 feature_blob_data = feature_blob->mutable_cpu_data() +
143 feature_blob->offset(n);
144 for (int d = 0; d < dim_features; ++d) {
145 datum.add_float_data(feature_blob_data[d]);
148 datum.SerializeToString(&value);
149 snprintf(key_str, kMaxKeyStrLength, "%d", image_index);
150 batch->Put(string(key_str), value);
152 if (image_index % 1000 == 0) {
153 db->Write(leveldb::WriteOptions(), batch);
154 LOG(ERROR)<< "Extracted features of " << image_index <<
157 batch = new leveldb::WriteBatch();
159 } // for (int n = 0; n < num_features; ++n)
160 } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
161 // write the last batch
162 if (image_index % 1000 != 0) {
163 db->Write(leveldb::WriteOptions(), batch);
164 LOG(ERROR)<< "Extracted features of " << image_index <<
167 batch = new leveldb::WriteBatch();
172 LOG(ERROR)<< "Successfully extracted the features!";