misc update
authorYangqing Jia <jiayq84@gmail.com>
Fri, 8 Nov 2013 21:50:12 +0000 (13:50 -0800)
committerYangqing Jia <jiayq84@gmail.com>
Fri, 8 Nov 2013 21:50:12 +0000 (13:50 -0800)
Makefile
include/caffe/common.hpp
include/caffe/net.hpp
include/caffe/util/io.hpp
src/caffe/net.cpp
src/caffe/proto/caffe.proto
src/caffe/util/io.cpp

index ab8c870..93cf537 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -41,10 +41,12 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64
 
 # define inclue and libaries
 # We put src here just for gtest
-INCLUDE_DIRS := ./src ./include /usr/local/include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
+INCLUDE_DIRS := ./src ./include /usr/local/include $(CUDA_INCLUDE_DIR) \
+       $(MKL_INCLUDE_DIR) /usr/include/python2.7
 LIBRARY_DIRS := /usr/lib /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
 LIBRARIES := cuda cudart cublas curand protobuf opencv_core opencv_highgui \
-       glog mkl_rt mkl_intel_thread leveldb snappy pthread boost_system
+       glog mkl_rt mkl_intel_thread leveldb snappy pthread boost_system \
+       python2.7 boost_python
 WARNINGS := -Wall
 
 COMMON_FLAGS := -DNDEBUG $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
index af42772..e7c5abe 100644 (file)
@@ -115,6 +115,7 @@ class Caffe {
   DISABLE_COPY_AND_ASSIGN(Caffe);
 };
 
+
 }  // namespace caffe
 
 #endif  // CAFFE_COMMON_HPP_
index 9bbfd37..14d80be 100644 (file)
@@ -22,10 +22,22 @@ namespace caffe {
 template <typename Dtype>
 class Net {
  public:
-  Net(const NetParameter& param,
+  Net(const NetParameter& param, const vector<Blob<Dtype>* >& bottom);
+  Net(const NetParameter& param, const vector<int>& bottom);
+  Net(const string& param_file, const vector<Blob<Dtype>* >& bottom);
+  Net(const string& param_file, const vector<int>& bottom);
+  Net(const string& param_file);
+  virtual ~Net() {}
+
+  // Initialize a network with the network parameter and the bottom vectors.
+  void Init(const NetParameter& param,
       const vector<Blob<Dtype>* >& bottom);
-  ~Net() {}
+
+  // Run forward using a set of bottom blobs, and return the result.
   const vector<Blob<Dtype>*>& Forward(const vector<Blob<Dtype>* > & bottom);
+  // Run forward using a serialized BlobProtoVector and return the result
+  // as a serialized BlobProtoVector
+  string Forward(const string& input_blob_protos);
   // The network backward should take no input and output, since it solely
   // computes the gradient w.r.t the parameters, and the data has already
   // been provided during the forward pass.
@@ -39,6 +51,7 @@ class Net {
   // For an already initialized net, CopyTrainedLayersFrom() copies the already
   // trained layers from another net parameter instance.
   void CopyTrainedLayersFrom(const NetParameter& param);
+  void CopyTrainedLayersFrom(const string trained_filename);
   // Writes the net to a proto.
   void ToProto(NetParameter* param, bool write_diff = false);
 
index e37e740..f14b1ee 100644 (file)
@@ -40,7 +40,13 @@ inline void WriteProtoToBinaryFile(
   WriteProtoToBinaryFile(proto, filename.c_str());
 }
 
-bool ReadImageToDatum(const string& filename, const int label, Datum* datum);
+bool ReadImageToDatum(const string& filename, const int label,
+    const int height, const int width, Datum* datum);
+
+inline bool ReadImageToDatum(const string& filename, const int label,
+    Datum* datum) {
+  ReadImageToDatum(filename, label, 0, 0, datum);
+}
 
 }  // namespace caffe
 
index 0c344fa..dec4203 100644 (file)
@@ -5,9 +5,12 @@
 #include <string>
 #include <vector>
 
+#include <boost/python.hpp>
+
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/layer.hpp"
 #include "caffe/net.hpp"
+#include "caffe/util/io.hpp"
 
 using std::pair;
 using std::map;
@@ -18,6 +21,57 @@ namespace caffe {
 template <typename Dtype>
 Net<Dtype>::Net(const NetParameter& param,
     const vector<Blob<Dtype>* >& bottom) {
+  Init(param, bottom);
+}
+
+template <typename Dtype>
+Net<Dtype>::Net(const NetParameter& param, const vector<int>& bottom) {
+  CHECK_EQ(bottom.size() % 4, 0);
+  vector<Blob<Dtype>* > bottom_blobs;
+  for (int i = 0; i < bottom.size(); i += 4) {
+    bottom_blobs.push_back(
+        new Blob<Dtype>(bottom[i], bottom[i+1], bottom[i+2], bottom[i+3]));
+  }
+  Init(param, bottom_blobs);
+  for (int i = 0; i < bottom_blobs.size(); ++i) {
+    delete bottom_blobs[i];
+  }
+}
+
+template <typename Dtype>
+Net<Dtype>::Net(const string& param_file,
+    const vector<Blob<Dtype>* >& bottom) {
+  NetParameter param;
+  ReadProtoFromTextFile(param_file, &param);
+  Init(param, bottom);
+}
+
+template <typename Dtype>
+Net<Dtype>::Net(const string& param_file, const vector<int>& bottom) {
+  CHECK_EQ(bottom.size() % 4, 0);
+  NetParameter param;
+  ReadProtoFromTextFile(param_file, &param);
+  vector<Blob<Dtype>* > bottom_blobs;
+  for (int i = 0; i < bottom.size(); i += 4) {
+    bottom_blobs.push_back(
+        new Blob<Dtype>(bottom[i], bottom[i+1], bottom[i+2], bottom[i+3]));
+  }
+  Init(param, bottom_blobs);
+  for (int i = 0; i < bottom_blobs.size(); ++i) {
+    delete bottom_blobs[i];
+  }
+}
+
+template <typename Dtype>
+Net<Dtype>::Net(const string& param_file) {
+  NetParameter param;
+  ReadProtoFromTextFile(param_file, &param);
+  Init(param, vector<Blob<Dtype>* >());
+}
+
+template <typename Dtype>
+void Net<Dtype>::Init(const NetParameter& param,
+    const vector<Blob<Dtype>* >& bottom) {
   // Basically, build all the layers and set up its connections.
   name_ = param.name();
   map<string, int> blob_name_to_idx;
@@ -182,6 +236,29 @@ const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
   return net_output_blobs_;
 }
 
+
+template <typename Dtype>
+string Net<Dtype>::Forward(const string& input_blob_protos) {
+  BlobProtoVector blob_proto_vec;
+  blob_proto_vec.ParseFromString(input_blob_protos);
+  CHECK_EQ(blob_proto_vec.blobs_size(), net_input_blob_indices_.size())
+      << "Incorrect input size.";
+  for (int i = 0; i < blob_proto_vec.blobs_size(); ++i) {
+    blobs_[net_input_blob_indices_[i]]->FromProto(blob_proto_vec.blobs(i));
+  }
+  for (int i = 0; i < layers_.size(); ++i) {
+    layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
+  }
+  blob_proto_vec.Clear();
+  for (int i = 0; i < layers_.size(); ++i) {
+    net_output_blobs_[i]->ToProto(blob_proto_vec.add_blobs());
+  }
+  string output;
+  blob_proto_vec.SerializeToString(&output);
+  return output;
+}
+
+
 template <typename Dtype>
 Dtype Net<Dtype>::Backward() {
   Dtype loss = 0;
@@ -226,6 +303,13 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
 }
 
 template <typename Dtype>
+void Net<Dtype>::CopyTrainedLayersFrom(const string trained_filename) {
+  NetParameter param;
+  ReadProtoFromBinaryFile(trained_filename, &param);
+  CopyTrainedLayersFrom(param);
+}
+
+template <typename Dtype>
 void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
   param->Clear();
   param->set_name(name_);
index 939df91..d2f2e53 100644 (file)
@@ -11,6 +11,12 @@ message BlobProto {
   repeated float diff = 6 [packed=true];
 }
 
+// The BlobProtoVector is simply a way to pass multiple blobproto instances
+// around.
+message BlobProtoVector {
+  repeated BlobProto blobs = 1;
+}
+
 message Datum {
   optional int32 channels = 1;
   optional int32 height = 2;
index 9bca5be..c4682f5 100644 (file)
@@ -65,14 +65,22 @@ void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
   CHECK(proto.SerializeToOstream(&output));
 }
 
-
-bool ReadImageToDatum(const string& filename, const int label, Datum* datum) {
+bool ReadImageToDatum(const string& filename, const int label,
+    const int height, const int width, Datum* datum) {
   cv::Mat cv_img;
-  cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
+  if (height > 0 && width > 0) {
+    cv::Mat cv_img_origin = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
+    cv::resize(cv_img_origin, cv_img, cv::Size(height, width));
+  } else {
+    cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
+  }
   if (!cv_img.data) {
     LOG(ERROR) << "Could not open or find file " << filename;
     return false;
   }
+  if (height > 0 && width > 0) {
+
+  }
   datum->set_channels(3);
   datum->set_height(cv_img.rows);
   datum->set_width(cv_img.cols);