Made a major change: when initializing a network, the input size are no longer provid...
authorYangqing Jia <jiayq84@gmail.com>
Sat, 16 Nov 2013 00:58:53 +0000 (16:58 -0800)
committerYangqing Jia <jiayq84@gmail.com>
Sat, 16 Nov 2013 00:58:53 +0000 (16:58 -0800)
13 files changed:
Makefile
examples/dump_network.cpp
examples/imagenet.prototxt
examples/imagenet_solver.prototxt
examples/imagenet_test.prototxt
examples/net_speed_benchmark.cpp
examples/test_net.cpp
include/caffe/net.hpp
python/caffe/imagenet/wrapper.py
python/caffe/pycaffe.cpp
src/caffe/net.cpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp

index ffcbad9..58c8c16 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
 # The makefile for caffe. Extremely hack.
 PROJECT := caffe
-TEST_GPUID := 1
+TEST_GPUID := 0
 
 # The target static library and shared library name
 NAME := lib$(PROJECT).so
index dc9466d..0d6e2d0 100644 (file)
@@ -50,7 +50,7 @@ int main(int argc, char** argv) {
     input_vec.push_back(input_blob.get());
   }
 
-  shared_ptr<Net<float> > caffe_net(new Net<float>(net_param, input_vec));
+  shared_ptr<Net<float> > caffe_net(new Net<float>(net_param));
   caffe_net->CopyTrainedLayersFrom(trained_net_param);
 
   string output_prefix(argv[4]);
index a821f43..5db585b 100644 (file)
@@ -3,8 +3,8 @@ layers {
   layer {
     name: "data"
     type: "data"
-    source: "/home/jiayq/caffe-train-leveldb"
-    meanfile: "/home/jiayq/ilsvrc2012_mean.binaryproto"
+    source: "/home/jiayq/Data/ILSVRC12/train-leveldb"
+    meanfile: "/home/jiayq/Data/ILSVRC12/image_mean.binaryproto"
     batchsize: 256
     cropsize: 227
     mirror: true
index 7d312ad..e00d3a5 100644 (file)
@@ -1,5 +1,4 @@
 train_net: "examples/imagenet.prototxt"
-test_net: "examples/imagenet_test.prototxt"
 test_iter: 1000
 test_interval: 1000
 base_lr: 0.01
@@ -11,4 +10,4 @@ max_iter: 4500000
 momentum: 0.9
 weight_decay: 0.0005
 snapshot: 10000
-snapshot_prefix: "alexnet_train"
\ No newline at end of file
+snapshot_prefix: "alexnet_train"
index 7a2d1a5..8494ae8 100644 (file)
@@ -3,9 +3,9 @@ layers {
   layer {
     name: "data"
     type: "data"
-    source: "/home/jiayq/caffe-val-leveldb"
-    meanfile: "/home/jiayq/ilsvrc2012_mean.binaryproto"
-    batchsize: 50
+    source: "/home/jiayq/Data/ILSVRC12/val-leveldb"
+    meanfile: "/home/jiayq/Data/ILSVRC12/image_mean.binaryproto"
+    batchsize: 200
     cropsize: 227
     mirror: false
   }
@@ -260,10 +260,18 @@ layers {
 }
 layers {
   layer {
+    name: "prob"
+    type: "softmax"
+  }
+  bottom: "fc8"
+  top: "prob"
+}
+layers {
+  layer {
     name: "accuracy"
     type: "accuracy"
   }
-  bottom: "fc8"
+  bottom: "prob"
   bottom: "label"
   top: "accuracy"
 }
\ No newline at end of file
index 97e8223..7a0b375 100644 (file)
@@ -26,12 +26,13 @@ int main(int argc, char** argv) {
   NetParameter net_param;
   ReadProtoFromTextFile(argv[1],
       &net_param);
-  vector<Blob<float>*> bottom_vec;
-  Net<float> caffe_net(net_param, bottom_vec);
+  Net<float> caffe_net(net_param);
 
   // Run the network without training.
   LOG(ERROR) << "Performing Forward";
-  caffe_net.Forward(bottom_vec);
+  // Note that for the speed benchmark, we will assume that the network does
+  // not take any input blobs.
+  caffe_net.Forward(vector<Blob<float>*>());
   LOG(ERROR) << "Performing Backward";
   LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
 
index f1af817..5b8305a 100644 (file)
@@ -32,10 +32,9 @@ int main(int argc, char** argv) {
     Caffe::set_mode(Caffe::CPU);
   }
 
-  vector<Blob<float>*> bottom_vec;
   NetParameter test_net_param;
   ReadProtoFromTextFile(argv[1], &test_net_param);
-  Net<float> caffe_test_net(test_net_param, bottom_vec);
+  Net<float> caffe_test_net(test_net_param);
   NetParameter trained_net_param;
   ReadProtoFromBinaryFile(argv[2], &trained_net_param);
   caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
@@ -44,9 +43,10 @@ int main(int argc, char** argv) {
   LOG(ERROR) << "Running " << total_iter << "Iterations.";
 
   double test_accuracy = 0;
+  vector<Blob<float>*> dummy_blob_input_vec;
   for (int i = 0; i < total_iter; ++i) {
     const vector<Blob<float>*>& result =
-        caffe_test_net.Forward(bottom_vec);
+        caffe_test_net.Forward(dummy_blob_input_vec);
     test_accuracy += result[0]->cpu_data()[0];
     LOG(ERROR) << "Batch " << i << ", accuracy: " << result[0]->cpu_data()[0];
   }
index 7c148ea..684d6c5 100644 (file)
@@ -22,16 +22,12 @@ namespace caffe {
 template <typename Dtype>
 class Net {
  public:
-  Net(const NetParameter& param, const vector<Blob<Dtype>* >& bottom);
-  Net(const NetParameter& param, const vector<int>& bottom);
-  Net(const string& param_file, const vector<Blob<Dtype>* >& bottom);
-  Net(const string& param_file, const vector<int>& bottom);
+  Net(const NetParameter& param);
   Net(const string& param_file);
   virtual ~Net() {}
 
-  // Initialize a network with the network parameter and the bottom vectors.
-  void Init(const NetParameter& param,
-      const vector<Blob<Dtype>* >& bottom);
+  // Initialize a network with the network parameter.
+  void Init(const NetParameter& param);
 
   // Run forward with the input blobs already fed separately. You can get the
   // input blobs using input_blobs().
index 8fc22d5..f71cb95 100644 (file)
@@ -76,8 +76,7 @@ class ImageNetClassifier(object):
       num = 1
     else:
       num = 10
-    self.caffenet = caffe.CaffeNet(model_def_file, pretrained_model,
-        [num, 3, CROPPED_DIM, CROPPED_DIM])
+    self.caffenet = caffe.CaffeNet(model_def_file, pretrained_model)
     self._output_blobs = [np.empty((num, num_output, 1, 1), dtype=np.float32)]
     self._center_only = center_only
 
index 9038760..9d3e04a 100644 (file)
@@ -25,13 +25,8 @@ using boost::python::object;
 // A simple wrapper over CaffeNet that runs the forward process.
 struct CaffeNet
 {
-  CaffeNet(string param_file, string pretrained_param_file,
-      list bottom) {
-    vector<int> bottom_vec;
-    for (int i = 0; i < len(bottom); ++i) {
-      bottom_vec.push_back(extract<int>(bottom[i]));
-    }
-    net_.reset(new Net<float>(param_file, bottom_vec));
+  CaffeNet(string param_file, string pretrained_param_file) {
+    net_.reset(new Net<float>(param_file));
     net_->CopyTrainedLayersFrom(pretrained_param_file);
   }
 
@@ -112,7 +107,7 @@ struct CaffeNet
 BOOST_PYTHON_MODULE(pycaffe)
 {
   boost::python::class_<CaffeNet>(
-      "CaffeNet", boost::python::init<string, string, list>())
+      "CaffeNet", boost::python::init<string, string>())
       .def("Forward", &CaffeNet::Forward)
       .def("set_mode_cpu", &CaffeNet::set_mode_cpu)
       .def("set_mode_gpu", &CaffeNet::set_mode_gpu)
index 6c793d7..5d7e447 100644 (file)
@@ -17,73 +17,34 @@ using std::set;
 namespace caffe {
 
 template <typename Dtype>
-Net<Dtype>::Net(const NetParameter& param,
-    const vector<Blob<Dtype>* >& bottom) {
-  Init(param, bottom);
-}
-
-template <typename Dtype>
-Net<Dtype>::Net(const NetParameter& param, const vector<int>& bottom) {
-  CHECK_EQ(bottom.size() % 4, 0);
-  vector<Blob<Dtype>* > bottom_blobs;
-  for (int i = 0; i < bottom.size(); i += 4) {
-    bottom_blobs.push_back(
-        new Blob<Dtype>(bottom[i], bottom[i+1], bottom[i+2], bottom[i+3]));
-  }
-  Init(param, bottom_blobs);
-  for (int i = 0; i < bottom_blobs.size(); ++i) {
-    delete bottom_blobs[i];
-  }
-}
-
-template <typename Dtype>
-Net<Dtype>::Net(const string& param_file,
-    const vector<Blob<Dtype>* >& bottom) {
-  NetParameter param;
-  ReadProtoFromTextFile(param_file, &param);
-  Init(param, bottom);
-}
-
-template <typename Dtype>
-Net<Dtype>::Net(const string& param_file, const vector<int>& bottom) {
-  CHECK_EQ(bottom.size() % 4, 0);
-  NetParameter param;
-  ReadProtoFromTextFile(param_file, &param);
-  vector<Blob<Dtype>* > bottom_blobs;
-  for (int i = 0; i < bottom.size(); i += 4) {
-    bottom_blobs.push_back(
-        new Blob<Dtype>(bottom[i], bottom[i+1], bottom[i+2], bottom[i+3]));
-  }
-  Init(param, bottom_blobs);
-  for (int i = 0; i < bottom_blobs.size(); ++i) {
-    delete bottom_blobs[i];
-  }
+Net<Dtype>::Net(const NetParameter& param) {
+  Init(param);
 }
 
 template <typename Dtype>
 Net<Dtype>::Net(const string& param_file) {
   NetParameter param;
   ReadProtoFromTextFile(param_file, &param);
-  Init(param, vector<Blob<Dtype>* >());
+  Init(param);
 }
 
 template <typename Dtype>
-void Net<Dtype>::Init(const NetParameter& param,
-    const vector<Blob<Dtype>* >& bottom) {
+void Net<Dtype>::Init(const NetParameter& param) {
   // Basically, build all the layers and set up its connections.
   name_ = param.name();
   map<string, int> blob_name_to_idx;
   set<string> available_blobs;
   int num_layers = param.layers_size();
-  CHECK_EQ(bottom.size(), param.input_size())
-      << "Incorrect bottom blob size.";
+  CHECK_EQ(param.input_size() * 4, param.input_dim_size())
+      << "Incorrect bottom blob dimension specifications.";
   // set the input blobs
   for (int i = 0; i < param.input_size(); ++i) {
     const string& blob_name = param.input(i);
-    CHECK_GT(bottom[i]->count(), 0);
     shared_ptr<Blob<Dtype> > blob_pointer(
-        new Blob<Dtype>(bottom[i]->num(), bottom[i]->channels(),
-            bottom[i]->height(), bottom[i]->width()));
+        new Blob<Dtype>(param.input_dim(i * 4),
+                        param.input_dim(i * 4 + 1),
+                        param.input_dim(i * 4 + 2),
+                        param.input_dim(i * 4 + 3)));
     blobs_.push_back(blob_pointer);
     blob_names_.push_back(blob_name);
     blob_need_backward_.push_back(false);
index d2f2e53..cac42bd 100644 (file)
@@ -102,7 +102,12 @@ message LayerConnection {
 message NetParameter {
   optional string name = 1; // consider giving the network a name
   repeated LayerConnection layers = 2; // a bunch of layers.
-  repeated string input = 3; // The input to the network
+  // The input blobs to the network.
+  repeated string input = 3;
+  // The dim of the input blobs. For each input blob there should be four
+  // values specifying the num, channels, height and width of the input blob.
+  // Thus, there should be a total of (4 * #input) numbers.
+  repeated int32 input_dim = 4;
 }
 
 message SolverParameter {
index 3562960..5c25641 100644 (file)
@@ -23,16 +23,13 @@ Solver<Dtype>::Solver(const SolverParameter& param)
   // Scaffolding code
   NetParameter train_net_param;
   ReadProtoFromTextFile(param_.train_net(), &train_net_param);
-  // For the training network, there should be no input - so we simply create
-  // a dummy bottom_vec instance to initialize the networks.
-  vector<Blob<Dtype>*> bottom_vec;
   LOG(INFO) << "Creating training net.";
-  net_.reset(new Net<Dtype>(train_net_param, bottom_vec));
+  net_.reset(new Net<Dtype>(train_net_param));
   if (param_.has_test_net()) {
     LOG(INFO) << "Creating testing net.";
     NetParameter test_net_param;
     ReadProtoFromTextFile(param_.test_net(), &test_net_param);
-    test_net_.reset(new Net<Dtype>(test_net_param, bottom_vec));
+    test_net_.reset(new Net<Dtype>(test_net_param));
     CHECK_GT(param_.test_iter(), 0);
     CHECK_GT(param_.test_interval(), 0);
   }