test loading from text file
authorYangqing Jia <jiayq84@gmail.com>
Sat, 28 Sep 2013 01:05:02 +0000 (18:05 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Sat, 28 Sep 2013 01:05:02 +0000 (18:05 -0700)
src/caffe/test/data/lenet.prototxt [new file with mode: 0644]
src/caffe/test/test_net_proto.cpp
src/caffe/util/io.cpp
src/caffe/util/io.hpp

diff --git a/src/caffe/test/data/lenet.prototxt b/src/caffe/test/data/lenet.prototxt
new file mode 100644 (file)
index 0000000..a7d578e
--- /dev/null
@@ -0,0 +1,113 @@
+name: "LeNet"
+bottom: "data"
+bottom: "label"
+layers {
+  layer {
+    name: "conv1"
+    type: "conv"
+    num_output: 20
+    kernelsize: 5
+    stride: 1
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "data"
+  top: "conv1"
+}
+layers {
+  layer {
+    name: "pool1"
+    type: "pool"
+    kernelsize: 2
+    stride: 2
+    pool: MAX
+  }
+  bottom: "conv1"
+  top: "pool1"
+}
+layers {
+  layer {
+    name: "conv2"
+    type: "conv"
+    num_output: 50
+    kernelsize: 5
+    stride: 1
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "pool1"
+  top: "conv2"
+}
+layers {
+  layer {
+    name: "pool2"
+    type: "pool"
+    kernelsize: 2
+    stride: 2
+    pool: MAX
+  }
+  bottom: "conv2"
+  top: "pool2"
+}
+layers {
+  layer {
+    name: "ip1"
+    type: "innerproduct"
+    num_output: 500
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "pool2"
+  top: "ip1"
+}
+layers {
+  layer {
+    name: "relu1"
+    type: "relu"
+  }
+  bottom: "ip1"
+  top: "relu1"
+}
+layers {
+  layer {
+    name: "ip2"
+    type: "innerproduct"
+    num_output: 10
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  bottom: "relu1"
+  top: "ip2"
+}
+layers {
+  layer {
+    name: "prob"
+    type: "softmax"
+  }
+  bottom: "ip2"
+  top: "prob"
+}
+layers {
+  layer {
+    name: "loss"
+    type: "multinomial_logistic_loss"
+  }
+  bottom: "prob"
+  bottom: "label"
+}
index 013bd67..b240346 100644 (file)
@@ -1,15 +1,19 @@
 // Copyright 2013 Yangqing Jia
 
-#include <cstring>
 #include <cuda_runtime.h>
+#include <fcntl.h>
 #include <google/protobuf/text_format.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <gtest/gtest.h>
 
+#include <cstring>
+
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/net.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
 
 #include "caffe/test/lenet.hpp"
 #include "caffe/test/test_caffe_main.hpp"
@@ -22,6 +26,11 @@ class NetProtoTest : public ::testing::Test {};
 typedef ::testing::Types<float, double> Dtypes;
 TYPED_TEST_CASE(NetProtoTest, Dtypes);
 
+TYPED_TEST(NetProtoTest, TestLoadFromText) {
+  NetParameter net_param;
+  ReadProtoFromTextFile("caffe/test/data/lenet.prototxt", &net_param);
+}
+
 TYPED_TEST(NetProtoTest, TestSetup) {
   NetParameter net_param;
   string lenet_string(kLENET);
@@ -65,7 +74,6 @@ TYPED_TEST(NetProtoTest, TestSetup) {
   caffe_net.Forward(bottom_vec, &top_vec);
   LOG(ERROR) << "Performing Backward";
   LOG(ERROR) << caffe_net.Backward();
-
 }
 
 }  // namespace caffe
index d827176..c1f5a96 100644 (file)
@@ -1,6 +1,9 @@
 // Copyright 2013 Yangqing Jia
 
 #include <stdint.h>
+#include <fcntl.h>
+#include <google/protobuf/text_format.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <opencv2/core/core.hpp>
 #include <opencv2/highgui/highgui.hpp>
 
@@ -15,6 +18,7 @@ using cv::Mat;
 using cv::Vec3b;
 using std::max;
 using std::string;
+using google::protobuf::io::FileInputStream;
 
 namespace caffe {
 
@@ -57,5 +61,13 @@ void WriteProtoToImage(const string& filename, const BlobProto& proto) {
   CHECK(cv::imwrite(filename, cv_img));
 }
 
+void ReadProtoFromTextFile(const char* filename,
+    ::google::protobuf::Message* proto) {
+  int fd = open(filename, O_RDONLY);
+  FileInputStream* input = new FileInputStream(fd);
+  CHECK(google::protobuf::TextFormat::Parse(input, proto));
+  delete input;
+  close(fd);
+}
 
 }  // namespace caffe
index f55de8a..90617be 100644 (file)
@@ -3,7 +3,10 @@
 #ifndef CAFFE_UTIL_IO_H_
 #define CAFFE_UTIL_IO_H_
 
+#include <google/protobuf/message.h>
+
 #include <string>
+
 #include "caffe/proto/caffe.pb.h"
 
 using std::string;
@@ -13,6 +16,12 @@ namespace caffe {
 void ReadImageToProto(const string& filename, BlobProto* proto);
 void WriteProtoToImage(const string& filename, const BlobProto& proto);
 
+void ReadProtoFromTextFile(const char* filename,
+    ::google::protobuf::Message* proto);
+inline void ReadProtoFromTextFile(const string& filename,
+    ::google::protobuf::Message* proto) {
+  ReadProtoFromTextFile(filename.c_str(), proto);
+}
 
 }  // namespace caffe