lenet training code
authorYangqing Jia <jiayq84@gmail.com>
Mon, 30 Sep 2013 23:32:11 +0000 (16:32 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 30 Sep 2013 23:32:11 +0000 (16:32 -0700)
12 files changed:
src/caffe/layer.hpp
src/caffe/layer_factory.hpp
src/caffe/layers/loss_layer.cu
src/caffe/net.cpp
src/caffe/test/data/lenet.prototxt
src/caffe/test/data/lenet_test.prototxt [new file with mode: 0644]
src/caffe/test/data/lenet_traintest.prototxt [new file with mode: 0644]
src/caffe/test/data/linear_regression.prototxt [new file with mode: 0644]
src/caffe/test/test_solver_mnist.cpp [new file with mode: 0644]
src/caffe/util/io.cpp
src/caffe/util/io.hpp
src/caffe/vision_layers.hpp

index b82f038..f575935 100644 (file)
@@ -40,7 +40,7 @@ class Layer {
   }
 
   // Writes the layer parameter to a protocol buffer
-  void ToProto(LayerParameter* param, bool write_diff = false);
+  virtual void ToProto(LayerParameter* param, bool write_diff = false);
 
  protected:
   // The protobuf that stores the layer parameters
index e9f8dbb..b0ecb9d 100644 (file)
@@ -19,7 +19,9 @@ namespace caffe {
 template <typename Dtype>
 Layer<Dtype>* GetLayer(const LayerParameter& param) {
   const std::string& type = param.type();
-  if (type == "conv") {
+  if (type == "accuracy") {
+    return new AccuracyLayer<Dtype>(param);
+  } else if (type == "conv") {
     return new ConvolutionLayer<Dtype>(param);
   } else if (type == "data") {
     return new DataLayer<Dtype>(param);
index 0c09a5d..9aedc3d 100644 (file)
@@ -1,10 +1,11 @@
 // Copyright 2013 Yangqing Jia
+#include <algorithm>
+#include <cmath>
+#include <cfloat>
 
 #include "caffe/layer.hpp"
 #include "caffe/vision_layers.hpp"
 #include "caffe/util/math_functions.hpp"
-#include <algorithm>
-#include <cmath>
 
 using std::max;
 
@@ -75,8 +76,47 @@ Dtype EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   return loss;
 }
 
+template <typename Dtype>
+void AccuracyLayer<Dtype>::SetUp(
+  const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 2) << "Accuracy Layer takes two blobs as input.";
+  CHECK_EQ(top->size(), 1) << "Accuracy Layer takes 1 output.";
+  CHECK_EQ(bottom[0]->num(), bottom[1]->num())
+      << "The data and label should have the same number.";
+  CHECK_EQ(bottom[1]->channels(), 1);
+  CHECK_EQ(bottom[1]->height(), 1);
+  CHECK_EQ(bottom[1]->width(), 1);
+  (*top)[0]->Reshape(1, 1, 1, 1);
+}
+
+template <typename Dtype>
+void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  Dtype accuracy = 0;
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* bottom_label = bottom[1]->cpu_data();
+  int num = bottom[0]->num();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  for (int i = 0; i < num; ++i) {
+    Dtype maxval = -FLT_MAX;
+    int max_id = 0;
+    for (int j = 0; j < dim; ++j) {
+      if (bottom_data[i * dim + j] > maxval) {
+        maxval = bottom_data[i * dim + j];
+        max_id = j;
+      }
+    }
+    if (max_id == (int)bottom_label[i]) {
+      ++accuracy;
+    }
+  }
+  accuracy /= num;
+  LOG(INFO) << "Accuracy: " << accuracy;
+  (*top)[0]->mutable_cpu_data()[0] = accuracy;
+}
+
 INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
 INSTANTIATE_CLASS(EuclideanLossLayer);
-
+INSTANTIATE_CLASS(AccuracyLayer);
 
 }  // namespace caffe
index bc699fa..2dd9b56 100644 (file)
@@ -157,12 +157,13 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
   for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
     param->add_input(blob_names_[net_input_blob_indices_[i]]);
   }
+  LOG(INFO) << "Serializing " << layers_.size() << " layers";
   for (int i = 0; i < layers_.size(); ++i) {
     LayerConnection* layer_connection = param->add_layers();
-    for (int j = 0; j < bottom_id_vecs_[i].size(); ++i) {
+    for (int j = 0; j < bottom_id_vecs_[i].size(); ++j) {
       layer_connection->add_bottom(blob_names_[bottom_id_vecs_[i][j]]);
     }
-    for (int j = 0; j < top_id_vecs_[i].size(); ++i) {
+    for (int j = 0; j < top_id_vecs_[i].size(); ++j) {
       layer_connection->add_top(blob_names_[top_id_vecs_[i][j]]);
     }
     LayerParameter* layer_parameter = layer_connection->mutable_layer();
index ac5c1ca..5dd4ef7 100644 (file)
@@ -4,7 +4,7 @@ layers {
     name: "mnist"
     type: "data"
     source: "caffe/test/data/mnist-train-leveldb"
-    batchsize: 32
+    batchsize: 64
     scale: 0.00390625
   }
   top: "data"
diff --git a/src/caffe/test/data/lenet_test.prototxt b/src/caffe/test/data/lenet_test.prototxt
new file mode 100644 (file)
index 0000000..447709c
--- /dev/null
@@ -0,0 +1,123 @@
+name: "LeNet-test"
+layers {
+  layer {
+    name: "mnist"
+    type: "data"
+    source: "caffe/test/data/mnist-test-leveldb"
+    batchsize: 100
+    scale: 0.00390625
+  }
+  top: "data"
+  top: "label"
+}
+layers {
+  layer {
+    name: "conv1"
+    type: "conv"
+    num_output: 20
+    kernelsize: 5
+    stride: 1
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "data"
+  top: "conv1"
+}
+layers {
+  layer {
+    name: "pool1"
+    type: "pool"
+    kernelsize: 2
+    stride: 2
+    pool: MAX
+  }
+  bottom: "conv1"
+  top: "pool1"
+}
+layers {
+  layer {
+    name: "conv2"
+    type: "conv"
+    num_output: 50
+    kernelsize: 5
+    stride: 1
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "pool1"
+  top: "conv2"
+}
+layers {
+  layer {
+    name: "pool2"
+    type: "pool"
+    kernelsize: 2
+    stride: 2
+    pool: MAX
+  }
+  bottom: "conv2"
+  top: "pool2"
+}
+layers {
+  layer {
+    name: "ip1"
+    type: "innerproduct"
+    num_output: 500
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "pool2"
+  top: "ip1"
+}
+layers {
+  layer {
+    name: "relu1"
+    type: "relu"
+  }
+  bottom: "ip1"
+  top: "relu1"
+}
+layers {
+  layer {
+    name: "ip2"
+    type: "innerproduct"
+    num_output: 10
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "relu1"
+  top: "ip2"
+}
+layers {
+  layer {
+    name: "prob"
+    type: "softmax"
+  }
+  bottom: "ip2"
+  top: "prob"
+}
+layers {
+  layer {
+    name: "accuracy"
+    type: "accuracy"
+  }
+  bottom: "prob"
+  bottom: "label"
+  top: "accuracy"
+}
diff --git a/src/caffe/test/data/lenet_traintest.prototxt b/src/caffe/test/data/lenet_traintest.prototxt
new file mode 100644 (file)
index 0000000..2bb4320
--- /dev/null
@@ -0,0 +1,123 @@
+name: "LeNet-test"
+layers {
+  layer {
+    name: "mnist"
+    type: "data"
+    source: "caffe/test/data/mnist-train-leveldb"
+    batchsize: 100
+    scale: 0.00390625
+  }
+  top: "data"
+  top: "label"
+}
+layers {
+  layer {
+    name: "conv1"
+    type: "conv"
+    num_output: 20
+    kernelsize: 5
+    stride: 1
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "data"
+  top: "conv1"
+}
+layers {
+  layer {
+    name: "pool1"
+    type: "pool"
+    kernelsize: 2
+    stride: 2
+    pool: MAX
+  }
+  bottom: "conv1"
+  top: "pool1"
+}
+layers {
+  layer {
+    name: "conv2"
+    type: "conv"
+    num_output: 50
+    kernelsize: 5
+    stride: 1
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "pool1"
+  top: "conv2"
+}
+layers {
+  layer {
+    name: "pool2"
+    type: "pool"
+    kernelsize: 2
+    stride: 2
+    pool: MAX
+  }
+  bottom: "conv2"
+  top: "pool2"
+}
+layers {
+  layer {
+    name: "ip1"
+    type: "innerproduct"
+    num_output: 500
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "pool2"
+  top: "ip1"
+}
+layers {
+  layer {
+    name: "relu1"
+    type: "relu"
+  }
+  bottom: "ip1"
+  top: "relu1"
+}
+layers {
+  layer {
+    name: "ip2"
+    type: "innerproduct"
+    num_output: 10
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "relu1"
+  top: "ip2"
+}
+layers {
+  layer {
+    name: "prob"
+    type: "softmax"
+  }
+  bottom: "ip2"
+  top: "prob"
+}
+layers {
+  layer {
+    name: "accuracy"
+    type: "accuracy"
+  }
+  bottom: "prob"
+  bottom: "label"
+  top: "accuracy"
+}
diff --git a/src/caffe/test/data/linear_regression.prototxt b/src/caffe/test/data/linear_regression.prototxt
new file mode 100644 (file)
index 0000000..d509a55
--- /dev/null
@@ -0,0 +1,34 @@
+name: "linear_regression_net"
+layers {
+  layer {
+    name: "datalayer"
+    type: "data"
+    source: "caffe/test/data/simple-linear-regression-leveldb"
+    batchsize: 32
+  }
+  top: "data"
+  top: "label"
+}
+layers {
+  layer {
+    name: "ip"
+    type: "innerproduct"
+    num_output: 1
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "data"
+  top: "ip"
+}
+layers {
+  layer {
+    name: "loss"
+    type: "euclidean_loss"
+  }
+  bottom: "ip"
+  bottom: "label"
+}
diff --git a/src/caffe/test/test_solver_mnist.cpp b/src/caffe/test/test_solver_mnist.cpp
new file mode 100644 (file)
index 0000000..4c8d3fd
--- /dev/null
@@ -0,0 +1,108 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cuda_runtime.h>
+#include <fcntl.h>
+#include <google/protobuf/text_format.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <gtest/gtest.h>
+
+#include <cstring>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/net.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+#include "caffe/optimization/solver.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+class MNISTSolverTest : public ::testing::Test {};
+
+typedef ::testing::Types<float> Dtypes;
+TYPED_TEST_CASE(MNISTSolverTest, Dtypes);
+
+TYPED_TEST(MNISTSolverTest, TestSolve) {
+  Caffe::set_mode(Caffe::GPU);
+
+  NetParameter net_param;
+  ReadProtoFromTextFile("caffe/test/data/lenet.prototxt",
+      &net_param);
+  vector<Blob<TypeParam>*> bottom_vec;
+  Net<TypeParam> caffe_net(net_param, bottom_vec);
+
+  // Run the network without training.
+  LOG(ERROR) << "Performing Forward";
+  caffe_net.Forward(bottom_vec);
+  LOG(ERROR) << "Performing Backward";
+  LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
+
+  SolverParameter solver_param;
+  solver_param.set_base_lr(0.01);
+  solver_param.set_display(0);
+  solver_param.set_max_iter(6000);
+  solver_param.set_lr_policy("inv");
+  solver_param.set_gamma(0.0001);
+  solver_param.set_power(0.75);
+  solver_param.set_momentum(0.9);
+
+  LOG(ERROR) << "Starting Optimization";
+  SGDSolver<TypeParam> solver(solver_param);
+  solver.Solve(&caffe_net);
+  LOG(ERROR) << "Optimization Done.";
+
+  // Run the network after training.
+  LOG(ERROR) << "Performing Forward";
+  caffe_net.Forward(bottom_vec);
+  LOG(ERROR) << "Performing Backward";
+  TypeParam loss = caffe_net.Backward();
+  LOG(ERROR) << "Final loss: " << loss;
+  EXPECT_LE(loss, 0.5);
+
+  NetParameter trained_net_param;
+  caffe_net.ToProto(&trained_net_param);
+  // LOG(ERROR) << "Writing to disk.";
+  // WriteProtoToBinaryFile(trained_net_param,
+  //     "caffe/test/data/lenet_trained.prototxt");
+
+  NetParameter traintest_net_param;
+  ReadProtoFromTextFile("caffe/test/data/lenet_traintest.prototxt",
+      &traintest_net_param);
+  Net<TypeParam> caffe_traintest_net(traintest_net_param, bottom_vec);
+  caffe_traintest_net.CopyTrainedLayersFrom(trained_net_param);
+
+  // Test run
+  double train_accuracy = 0;
+  int batch_size = traintest_net_param.layers(0).layer().batchsize();
+  for (int i = 0; i < 60000 / batch_size; ++i) {
+    const vector<Blob<TypeParam>*>& result =
+        caffe_traintest_net.Forward(bottom_vec);
+    train_accuracy += result[0]->cpu_data()[0];
+  }
+  train_accuracy /= 60000 / batch_size;
+  LOG(ERROR) << "Train accuracy:" << train_accuracy;
+  EXPECT_GE(train_accuracy, 0.98);
+
+  NetParameter test_net_param;
+  ReadProtoFromTextFile("caffe/test/data/lenet_test.prototxt", &test_net_param);
+  Net<TypeParam> caffe_test_net(test_net_param, bottom_vec);
+  caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
+
+  // Test run
+  double test_accuracy = 0;
+  batch_size = test_net_param.layers(0).layer().batchsize();
+  for (int i = 0; i < 10000 / batch_size; ++i) {
+    const vector<Blob<TypeParam>*>& result =
+        caffe_test_net.Forward(bottom_vec);
+    test_accuracy += result[0]->cpu_data()[0];
+  }
+  test_accuracy /= 10000 / batch_size;
+  LOG(ERROR) << "Test accuracy:" << test_accuracy;
+  EXPECT_GE(test_accuracy, 0.98);
+}
+
+}  // namespace caffe
index ca20962..0d4f9bb 100644 (file)
@@ -9,6 +9,8 @@
 
 #include <algorithm>
 #include <string>
+#include <iostream>
+#include <fstream>
 
 #include "caffe/common.hpp"
 #include "caffe/util/io.hpp"
 
 using cv::Mat;
 using cv::Vec3b;
+using std::fstream;
+using std::ios;
 using std::max;
 using std::string;
 using google::protobuf::io::FileInputStream;
+using google::protobuf::io::FileOutputStream;
 
 namespace caffe {
 
@@ -80,4 +85,22 @@ void ReadProtoFromTextFile(const char* filename,
   close(fd);
 }
 
+void WriteProtoToTextFile(const Message& proto, const char* filename) {
+  int fd = open(filename, O_WRONLY);
+  FileOutputStream* output = new FileOutputStream(fd);
+  CHECK(google::protobuf::TextFormat::Print(proto, output));
+  delete output;
+  close(fd);
+}
+
+void ReadProtoFromBinaryFile(const char* filename, Message* proto) {
+  fstream input(filename, ios::in | ios::binary);
+  CHECK(proto->ParseFromIstream(&input));
+}
+
+void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
+  fstream output(filename, ios::out | ios::trunc | ios::binary);
+  CHECK(proto.SerializeToOstream(&output));
+}
+
 }  // namespace caffe
index 29f7f41..57beef1 100644 (file)
@@ -11,6 +11,7 @@
 #include "caffe/proto/caffe.pb.h"
 
 using std::string;
+using ::google::protobuf::Message;
 
 namespace caffe {
 
@@ -33,12 +34,30 @@ inline void WriteBlobToImage(const string& filename, const Blob<Dtype>& blob) {
 }
 
 void ReadProtoFromTextFile(const char* filename,
-    ::google::protobuf::Message* proto);
+    Message* proto);
 inline void ReadProtoFromTextFile(const string& filename,
-    ::google::protobuf::Message* proto) {
+    Message* proto) {
   ReadProtoFromTextFile(filename.c_str(), proto);
 }
 
+void WriteProtoToTextFile(const Message& proto, const char* filename);
+inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
+  WriteProtoToTextFile(proto, filename.c_str());
+}
+
+void ReadProtoFromBinaryFile(const char* filename,
+    Message* proto);
+inline void ReadProtoFromBinaryFile(const string& filename,
+    Message* proto) {
+  ReadProtoFromBinaryFile(filename.c_str(), proto);
+}
+
+void WriteProtoToBinaryFile(const Message& proto, const char* filename);
+inline void WriteProtoToBinaryFile(const Message& proto, const string& filename) {
+  WriteProtoToBinaryFile(proto, filename.c_str());
+}
+
+
 }  // namespace caffe
 
 #endif   // CAFFE_UTIL_IO_H_
index 29d2eb3..3aa43b2 100644 (file)
@@ -323,6 +323,26 @@ class EuclideanLossLayer : public Layer<Dtype> {
   Blob<Dtype> difference_;
 };
 
+template <typename Dtype>
+class AccuracyLayer : public Layer<Dtype> {
+ public:
+  explicit AccuracyLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  // The loss layer will do nothing during forward - all computation are
+  // carried out in the backward pass.
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  // The accuracy layer should not be used to compute backward operations.
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+    NOT_IMPLEMENTED;
+    return Dtype(0.);
+  }
+};
 
 }  // namespace caffe