Fix bugs in the feature extraction example
author: Kai Li <kaili_kloud@163.com>
Tue, 25 Feb 2014 18:47:08 +0000 (02:47 +0800)
committer: Kai Li <kaili_kloud@163.com>
Wed, 19 Mar 2014 15:04:42 +0000 (23:04 +0800)
examples/demo_extract_features.cpp

index 479ce55..7385dab 100644 (file)
@@ -30,10 +30,10 @@ int feature_extraction_pipeline(int argc, char** argv) {
   if (argc < num_required_args) {
     LOG(ERROR)<<
         "This program takes in a trained network and an input data layer, and then"
-        "  extract features of the input data produced by the net."
+        " extract features of the input data produced by the net.\n"
         "Usage: demo_extract_features  pretrained_net_param"
-        "  extract_feature_blob_name  data_prototxt  data_layer_name"
-        "  save_feature_leveldb_name  [CPU/GPU]  [DEVICE_ID=0]";
+        "  feature_extraction_proto_file  extract_feature_blob_name"
+        "  save_feature_leveldb_name  num_mini_batches  [CPU/GPU]  [DEVICE_ID=0]";
     return 1;
   }
   int arg_pos = num_required_args;
@@ -58,86 +58,78 @@ int feature_extraction_pipeline(int argc, char** argv) {
   NetParameter pretrained_net_param;
 
   arg_pos = 0;  // the name of the executable
-  // We directly load the net param from trained file
   string pretrained_binary_proto(argv[++arg_pos]);
   ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(),
                           &pretrained_net_param);
+
+  // Expected prototxt contains at least one data layer such as
+  //  the layer data_layer_name and one feature blob such as the
+  //  fc7 top blob to extract features.
+  /*
+    layers {
+      layer {
+        name: "data_layer_name"
+        type: "data"
+        source: "/path/to/your/images/to/extract/feature/images_leveldb"
+        meanfile: "/path/to/your/image_mean.binaryproto"
+        batchsize: 128
+        cropsize: 227
+        mirror: false
+      }
+      top: "data_blob_name"
+      top: "label_blob_name"
+    }
+    layers {
+      layer {
+        name: "drop7"
+        type: "dropout"
+        dropout_ratio: 0.5
+      }
+      bottom: "fc7"
+      top: "fc7"
+    }
+    */
+  NetParameter feature_extraction_net_param;
+  string feature_extraction_proto(argv[++arg_pos]);
+  ReadProtoFromTextFile(feature_extraction_proto,
+                        &feature_extraction_net_param);
   shared_ptr<Net<Dtype> > feature_extraction_net(
-      new Net<Dtype>(pretrained_net_param));
+      new Net<Dtype>(feature_extraction_net_param));
+  feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);
 
   string extract_feature_blob_name(argv[++arg_pos]);
   if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) {
     LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name <<
-    " in trained network " << pretrained_binary_proto;
+    " in the network " << feature_extraction_proto;
     return 1;
   }
 
-  // Expected prototxt contains at least one data layer to extract features.
-  /*
-   layers {
-   layer {
-   name: "data_layer_name"
-   type: "data"
-   source: "/path/to/your/images/to/extract/feature/images_leveldb"
-   meanfile: "/path/to/your/image_mean.binaryproto"
-   batchsize: 128
-   cropsize: 227
-   mirror: false
-   }
-   top: "data_blob_name"
-   top: "label_blob_name"
-   }
-   */
-  string data_prototxt(argv[++arg_pos]);
-  string data_layer_name(argv[++arg_pos]);
-  NetParameter data_net_param;
-  ReadProtoFromTextFile(data_prototxt.c_str(), &data_net_param);
-  LayerParameter data_layer_param;
-  int num_layer;
-  for (num_layer = 0; num_layer < data_net_param.layers_size(); ++num_layer) {
-    if (data_layer_name == data_net_param.layers(num_layer).layer().name()) {
-      data_layer_param = data_net_param.layers(num_layer).layer();
-      break;
-    }
-  }
-  if (num_layer = data_net_param.layers_size()) {
-    LOG(ERROR) << "Unknown data layer name " << data_layer_name <<
-        " in prototxt " << data_prototxt;
-  }
-
   string save_feature_leveldb_name(argv[++arg_pos]);
   leveldb::DB* db;
   leveldb::Options options;
   options.error_if_exists = true;
   options.create_if_missing = true;
   options.write_buffer_size = 268435456;
-  LOG(INFO) << "Opening leveldb " << argv[3];
+  LOG(INFO) << "Opening leveldb " << save_feature_leveldb_name;
   leveldb::Status status = leveldb::DB::Open(
       options, save_feature_leveldb_name.c_str(), &db);
   CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
 
+  int num_mini_batches = atoi(argv[++arg_pos]);
+
   LOG(ERROR)<< "Extacting Features";
-  DataLayer<Dtype> data_layer(data_layer_param);
-  vector<Blob<Dtype>*> bottom_vec_that_data_layer_does_not_need_;
-  vector<Blob<Dtype>*> top_vec;
-  data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
-  int batch_index = 0;
-  int image_index = 0;
 
   Datum datum;
   leveldb::WriteBatch* batch = new leveldb::WriteBatch();
   const int max_key_str_length = 100;
   char key_str[max_key_str_length];
   int num_bytes_of_binary_code = sizeof(Dtype);
-  // TODO: DataLayer seem to rotate from the last record to the first
-  // how to judge that all the data record have been enumerated?
-  while (top_vec.size()) { // data_layer still outputs data
-    LOG(ERROR)<< "Batch " << batch_index << " feature extraction";
-    feature_extraction_net->Forward(top_vec);
+  vector<Blob<float>* > input_vec;
+  int image_index = 0;
+  for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
+    feature_extraction_net->Forward(input_vec);
     const shared_ptr<Blob<Dtype> > feature_blob =
         feature_extraction_net->GetBlob(extract_feature_blob_name);
-
-    LOG(ERROR) << "Batch " << batch_index << " save extracted features";
     int num_features = feature_blob->num();
     int dim_features = feature_blob->count() / num_features;
     for (int n = 0; n < num_features; ++n) {
@@ -165,17 +157,14 @@ int feature_extraction_pipeline(int argc, char** argv) {
          batch = new leveldb::WriteBatch();
        }
     }
-    // write the last batch
-    if (image_index % 1000 != 0) {
-      db->Write(leveldb::WriteOptions(), batch);
-      LOG(ERROR) << "Extracted features of " << image_index << " query images.";
-      delete batch;
-      batch = new leveldb::WriteBatch();
-    }
-
-    data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
-    ++batch_index;
-  } //  while (top_vec.size()) {
+  } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
+  // write the last batch
+  if (image_index % 1000 != 0) {
+    db->Write(leveldb::WriteOptions(), batch);
+    LOG(ERROR) << "Extracted features of " << image_index << " query images.";
+    delete batch;
+    batch = new leveldb::WriteBatch();
+  }
 
   delete batch;
   delete db;