0766eea63e33aa0c252cd218585d995fe0a80ac2
[platform/upstream/caffeonacl.git] / tools / extract_features.cpp
1 // Copyright 2014 kloudkl@github
2
3 #include <stdio.h>  // for snprintf
4 #include <cuda_runtime.h>
5 #include <google/protobuf/text_format.h>
6 #include <leveldb/db.h>
7 #include <leveldb/write_batch.h>
8 #include <string>
9 #include <vector>
10
11 #include "caffe/blob.hpp"
12 #include "caffe/common.hpp"
13 #include "caffe/net.hpp"
14 #include "caffe/vision_layers.hpp"
15 #include "caffe/proto/caffe.pb.h"
16 #include "caffe/util/io.hpp"
17
18 using namespace caffe;  // NOLINT(build/namespaces)
19
// Declared ahead of main() because the definition appears further down.
template <typename Dtype>
int feature_extraction_pipeline(int argc, char** argv);

// Program entry point: runs the feature-extraction pipeline with
// single-precision (float) blobs and returns the pipeline's exit status
// unchanged.
int main(int argc, char** argv) {
  const int exit_status = feature_extraction_pipeline<float>(argc, argv);
  return exit_status;
  // Alternative instantiation, kept for reference:
  //  return feature_extraction_pipeline<double>(argc, argv);
}
27
28 template<typename Dtype>
29 int feature_extraction_pipeline(int argc, char** argv) {
30   const int num_required_args = 6;
31   if (argc < num_required_args) {
32     LOG(ERROR)<<
33     "This program takes in a trained network and an input data layer, and then"
34     " extract features of the input data produced by the net.\n"
35     "Usage: demo_extract_features  pretrained_net_param"
36     "  feature_extraction_proto_file  extract_feature_blob_name"
37     "  save_feature_leveldb_name  num_mini_batches  [CPU/GPU]  [DEVICE_ID=0]";
38     return 1;
39   }
40   int arg_pos = num_required_args;
41
42   arg_pos = num_required_args;
43   if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
44     LOG(ERROR)<< "Using GPU";
45     uint device_id = 0;
46     if (argc > arg_pos + 1) {
47       device_id = atoi(argv[arg_pos + 1]);
48       CHECK_GE(device_id, 0);
49     }
50     LOG(ERROR) << "Using Device_id=" << device_id;
51     Caffe::SetDevice(device_id);
52     Caffe::set_mode(Caffe::GPU);
53   } else {
54     LOG(ERROR) << "Using CPU";
55     Caffe::set_mode(Caffe::CPU);
56   }
57   Caffe::set_phase(Caffe::TEST);
58
59   NetParameter pretrained_net_param;
60
61   arg_pos = 0;  // the name of the executable
62   string pretrained_binary_proto(argv[++arg_pos]);
63   ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(),
64                           &pretrained_net_param);
65
66   // Expected prototxt contains at least one data layer such as
67   //  the layer data_layer_name and one feature blob such as the
68   //  fc7 top blob to extract features.
69   /*
70    layers {
71    layer {
72    name: "data_layer_name"
73    type: "data"
74    source: "/path/to/your/images/to/extract/feature/images_leveldb"
75    meanfile: "/path/to/your/image_mean.binaryproto"
76    batchsize: 128
77    cropsize: 227
78    mirror: false
79    }
80    top: "data_blob_name"
81    top: "label_blob_name"
82    }
83    layers {
84    layer {
85    name: "drop7"
86    type: "dropout"
87    dropout_ratio: 0.5
88    }
89    bottom: "fc7"
90    top: "fc7"
91    }
92    */
93   NetParameter feature_extraction_net_param;
94   string feature_extraction_proto(argv[++arg_pos]);
95   ReadProtoFromTextFile(feature_extraction_proto,
96                         &feature_extraction_net_param);
97   shared_ptr<Net<Dtype> > feature_extraction_net(
98       new Net<Dtype>(feature_extraction_net_param));
99   feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);
100
101   string extract_feature_blob_name(argv[++arg_pos]);
102   CHECK(feature_extraction_net->has_blob(extract_feature_blob_name))
103       << "Unknown feature blob name " << extract_feature_blob_name
104       << " in the network " << feature_extraction_proto;
105
106   string save_feature_leveldb_name(argv[++arg_pos]);
107   leveldb::DB* db;
108   leveldb::Options options;
109   options.error_if_exists = true;
110   options.create_if_missing = true;
111   options.write_buffer_size = 268435456;
112   LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name;
113   leveldb::Status status = leveldb::DB::Open(options,
114                                              save_feature_leveldb_name.c_str(),
115                                              &db);
116   CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
117
118   int num_mini_batches = atoi(argv[++arg_pos]);
119
120   LOG(ERROR)<< "Extacting Features";
121
122   Datum datum;
123   leveldb::WriteBatch* batch = new leveldb::WriteBatch();
124   const int kMaxKeyStrLength = 100;
125   char key_str[kMaxKeyStrLength];
126   int num_bytes_of_binary_code = sizeof(Dtype);
127   vector<Blob<float>*> input_vec;
128   int image_index = 0;
129   for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
130     feature_extraction_net->Forward(input_vec);
131     const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
132         ->blob_by_name(extract_feature_blob_name);
133     int num_features = feature_blob->num();
134     int dim_features = feature_blob->count() / num_features;
135     Dtype* feature_blob_data;
136     for (int n = 0; n < num_features; ++n) {
137       datum.set_height(dim_features);
138       datum.set_width(1);
139       datum.set_channels(1);
140       datum.clear_data();
141       datum.clear_float_data();
142       feature_blob_data = feature_blob->mutable_cpu_data() +
143           feature_blob->offset(n);
144       for (int d = 0; d < dim_features; ++d) {
145         datum.add_float_data(feature_blob_data[d]);
146       }
147       string value;
148       datum.SerializeToString(&value);
149       snprintf(key_str, kMaxKeyStrLength, "%d", image_index);
150       batch->Put(string(key_str), value);
151       ++image_index;
152       if (image_index % 1000 == 0) {
153         db->Write(leveldb::WriteOptions(), batch);
154         LOG(ERROR)<< "Extracted features of " << image_index <<
155             " query images.";
156         delete batch;
157         batch = new leveldb::WriteBatch();
158       }
159     }  // for (int n = 0; n < num_features; ++n)
160   }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
161   // write the last batch
162   if (image_index % 1000 != 0) {
163     db->Write(leveldb::WriteOptions(), batch);
164     LOG(ERROR)<< "Extracted features of " << image_index <<
165         " query images.";
166     delete batch;
167     batch = new leveldb::WriteBatch();
168   }
169
170   delete batch;
171   delete db;
172   LOG(ERROR)<< "Successfully extracted the features!";
173   return 0;
174 }
175