Enhance help, log message & format of the feature extraction example
authorKai Li <kaili_kloud@163.com>
Tue, 25 Feb 2014 19:46:32 +0000 (03:46 +0800)
committerKai Li <kaili_kloud@163.com>
Wed, 19 Mar 2014 15:04:42 +0000 (23:04 +0800)
examples/demo_extract_features.cpp

index 7385dab..d16ee70 100644 (file)
@@ -15,7 +15,6 @@
 
 using namespace caffe;
 
-
 template<typename Dtype>
 int feature_extraction_pipeline(int argc, char** argv);
 
@@ -29,11 +28,11 @@ int feature_extraction_pipeline(int argc, char** argv) {
   const int num_required_args = 6;
   if (argc < num_required_args) {
     LOG(ERROR)<<
-        "This program takes in a trained network and an input data layer, and then"
-        " extract features of the input data produced by the net.\n"
-        "Usage: demo_extract_features  pretrained_net_param"
-        "  feature_extraction_proto_file  extract_feature_blob_name"
-        "  save_feature_leveldb_name  num_mini_batches  [CPU/GPU]  [DEVICE_ID=0]";
+    "This program takes in a trained network and an input data layer, and then"
+    " extract features of the input data produced by the net.\n"
+    "Usage: demo_extract_features  pretrained_net_param"
+    "  feature_extraction_proto_file  extract_feature_blob_name"
+    "  save_feature_leveldb_name  num_mini_batches  [CPU/GPU]  [DEVICE_ID=0]";
     return 1;
   }
   int arg_pos = num_required_args;
@@ -63,33 +62,34 @@ int feature_extraction_pipeline(int argc, char** argv) {
                           &pretrained_net_param);
 
   // Expected prototxt contains at least one data layer such as
-   //  the layer data_layer_name and one feature blob such as the
-   //  fc7 top blob to extract features.
-   /*
-    layers {
-      layer {
-        name: "data_layer_name"
-        type: "data"
-        source: "/path/to/your/images/to/extract/feature/images_leveldb"
-        meanfile: "/path/to/your/image_mean.binaryproto"
-        batchsize: 128
-        cropsize: 227
-        mirror: false
-      }
-      top: "data_blob_name"
-      top: "label_blob_name"
-    }
-    layers {
-      layer {
-        name: "drop7"
-        type: "dropout"
-        dropout_ratio: 0.5
-      }
-      bottom: "fc7"
-      top: "fc7"
-    }
-    */
-  NetParameter feature_extraction_net_param;;
+  //  the layer data_layer_name and one feature blob such as the
+  //  fc7 top blob to extract features.
+  /*
+   layers {
+   layer {
+   name: "data_layer_name"
+   type: "data"
+   source: "/path/to/your/images/to/extract/feature/images_leveldb"
+   meanfile: "/path/to/your/image_mean.binaryproto"
+   batchsize: 128
+   cropsize: 227
+   mirror: false
+   }
+   top: "data_blob_name"
+   top: "label_blob_name"
+   }
+   layers {
+   layer {
+   name: "drop7"
+   type: "dropout"
+   dropout_ratio: 0.5
+   }
+   bottom: "fc7"
+   top: "fc7"
+   }
+   */
+  NetParameter feature_extraction_net_param;
+  ;
   string feature_extraction_proto(argv[++arg_pos]);
   ReadProtoFromTextFile(feature_extraction_proto,
                         &feature_extraction_net_param);
@@ -98,11 +98,9 @@ int feature_extraction_pipeline(int argc, char** argv) {
   feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);
 
   string extract_feature_blob_name(argv[++arg_pos]);
-  if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) {
-    LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name <<
-    " in the network " << feature_extraction_proto;
-    return 1;
-  }
+  CHECK(feature_extraction_net->HasBlob(extract_feature_blob_name))
+      << "Unknown feature blob name " << extract_feature_blob_name
+      << " in the network " << feature_extraction_proto;
 
   string save_feature_leveldb_name(argv[++arg_pos]);
   leveldb::DB* db;
@@ -110,9 +108,10 @@ int feature_extraction_pipeline(int argc, char** argv) {
   options.error_if_exists = true;
   options.create_if_missing = true;
   options.write_buffer_size = 268435456;
-  LOG(INFO) << "Opening leveldb " << save_feature_leveldb_name;
-  leveldb::Status status = leveldb::DB::Open(
-      options, save_feature_leveldb_name.c_str(), &db);
+  LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name;
+  leveldb::Status status = leveldb::DB::Open(options,
+                                             save_feature_leveldb_name.c_str(),
+                                             &db);
   CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
 
   int num_mini_batches = atoi(argv[++arg_pos]);
@@ -124,51 +123,52 @@ int feature_extraction_pipeline(int argc, char** argv) {
   const int max_key_str_length = 100;
   char key_str[max_key_str_length];
   int num_bytes_of_binary_code = sizeof(Dtype);
-  vector<Blob<float>* > input_vec;
+  vector<Blob<float>*> input_vec;
   int image_index = 0;
   for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
     feature_extraction_net->Forward(input_vec);
-    const shared_ptr<Blob<Dtype> > feature_blob =
-        feature_extraction_net->GetBlob(extract_feature_blob_name);
+    const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
+        ->GetBlob(extract_feature_blob_name);
     int num_features = feature_blob->num();
     int dim_features = feature_blob->count() / num_features;
     for (int n = 0; n < num_features; ++n) {
-       datum.set_height(dim_features);
-       datum.set_width(1);
-       datum.set_channels(1);
-       datum.clear_data();
-       datum.clear_float_data();
-       string* datum_string = datum.mutable_data();
-       const Dtype* feature_blob_data = feature_blob->cpu_data();
-       for (int d = 0; d < dim_features; ++d) {
-         const char* data_byte = reinterpret_cast<const char*>(feature_blob_data + d);
-         for(int i = 0; i < num_bytes_of_binary_code; ++i) {
-           datum_string->push_back(data_byte[i]);
-         }
-       }
-       string value;
-       datum.SerializeToString(&value);
-       snprintf(key_str, max_key_str_length, "%d", image_index);
-       batch->Put(string(key_str), value);
-       if (++image_index % 1000 == 0) {
-         db->Write(leveldb::WriteOptions(), batch);
-         LOG(ERROR) << "Extracted features of " << image_index << " query images.";
-         delete batch;
-         batch = new leveldb::WriteBatch();
-       }
+      datum.set_height(dim_features);
+      datum.set_width(1);
+      datum.set_channels(1);
+      datum.clear_data();
+      datum.clear_float_data();
+      string* datum_string = datum.mutable_data();
+      const Dtype* feature_blob_data = feature_blob->cpu_data();
+      for (int d = 0; d < dim_features; ++d) {
+        const char* data_byte = reinterpret_cast<const char*>(feature_blob_data
+            + d);
+        for (int i = 0; i < num_bytes_of_binary_code; ++i) {
+          datum_string->push_back(data_byte[i]);
+        }
+      }
+      string value;
+      datum.SerializeToString(&value);
+      snprintf(key_str, max_key_str_length, "%d", image_index);
+      batch->Put(string(key_str), value);
+      if (++image_index % 1000 == 0) {
+        db->Write(leveldb::WriteOptions(), batch);
+        LOG(ERROR)<< "Extracted features of " << image_index << " query images.";
+        delete batch;
+        batch = new leveldb::WriteBatch();
+      }
     }
-  } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
+  }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
   // write the last batch
   if (image_index % 1000 != 0) {
     db->Write(leveldb::WriteOptions(), batch);
-    LOG(ERROR) << "Extracted features of " << image_index << " query images.";
+    LOG(ERROR)<< "Extracted features of " << image_index << " query images.";
     delete batch;
     batch = new leveldb::WriteBatch();
   }
 
   delete batch;
   delete db;
-  LOG(ERROR)<< "Successfully ended!";
+  LOG(ERROR)<< "Successfully extracted the features!";
   return 0;
 }