misc update
authorYangqing Jia <jiayq84@gmail.com>
Sat, 14 Sep 2013 20:55:26 +0000 (13:55 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Sat, 14 Sep 2013 20:55:26 +0000 (13:55 -0700)
src/caffeine/blob.cpp
src/caffeine/blob.hpp
src/caffeine/common.cpp
src/caffeine/common.hpp
src/caffeine/core_layers.hpp
src/caffeine/layer.cpp [moved from src/caffeine/base.cpp with 70% similarity]
src/caffeine/layer.hpp [moved from src/caffeine/base.hpp with 60% similarity]
src/caffeine/neuron_layer.cpp
src/caffeine/proto/layer_param.proto
src/caffeine/test_caffeine_main.cpp [moved from src/caffeine/test_caffeine.cpp with 100% similarity]
src/caffeine/test_syncedmem.cpp

index 9193d6e..80d4acf 100644 (file)
@@ -77,10 +77,41 @@ Dtype* Blob<Dtype>::mutable_gpu_diff() {
 }
 
 template <typename Dtype>
-void Blob<Dtype>::update() {
+void Blob<Dtype>::Update() {
   // not implemented yet.
-  
+  LOG(FATAL) << "not implemented";
+}
 
+template <typename Dtype>
+void Blob<Dtype>::FromProto(const BlobProto& proto) {
+  Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
+  // copy data
+  Dtype* data_vec = mutable_cpu_data();
+  for (int i = 0; i < count_; ++i) {
+    data_vec[i] = proto.data(i);
+  }
+  Dtype* diff_vec = mutable_cpu_diff();
+  for (int i = 0; i < count_; ++i) {
+    diff_vec[i] = proto.diff(i);
+  }
+}
+
+template <typename Dtype>
+void Blob<Dtype>::ToProto(BlobProto* proto) {
+  proto->set_num(num_);
+  proto->set_channels(channels_);
+  proto->set_height(height_);
+  proto->set_width(width_);
+  proto->clear_data();
+  proto->clear_diff();
+  const Dtype* data_vec = cpu_data();
+  for (int i = 0; i < count_; ++i) {
+    proto->add_data(data_vec[i]);
+  }
+  const Dtype* diff_vec = cpu_diff();
+  for (int i = 0; i < count_; ++i) {
+    proto->add_diff(diff_vec[i]);
+  }
 }
 
 template class Blob<float>;
index 5bf8f49..4c0bf0d 100644 (file)
@@ -3,6 +3,7 @@
 
 #include "caffeine/common.hpp"
 #include "caffeine/syncedmem.hpp"
+#include "caffeine/proto/layer_param.pb.h"
 
 namespace caffeine {
 
@@ -31,7 +32,9 @@ class Blob {
   Dtype* mutable_gpu_data();
   Dtype* mutable_cpu_diff();
   Dtype* mutable_gpu_diff();
-  void update();
+  void Update();
+  void FromProto(const BlobProto& proto);
+  void ToProto(BlobProto* proto);
  private:
   shared_ptr<SyncedMemory> data_;
   shared_ptr<SyncedMemory> diff_;
index c4062b7..0cee065 100644 (file)
@@ -28,8 +28,9 @@ Caffeine::Brew Caffeine::mode() {
   return Get().mode_;
 }
 
-Caffeine::Brew Caffeine::set_mode(Caffeine::Brew mode) {
+void Caffeine::set_mode(Caffeine::Brew mode) {
   Get().mode_ = mode;
 }
 
 }  // namespace caffeine
+
index 1aad02a..f3197c8 100644 (file)
@@ -28,7 +28,7 @@ class Caffeine {
   static cublasHandle_t cublas_handle();
   static Brew mode();
   // The setters for the variables
-  static Brew set_mode(Brew mode);
+  static void set_mode(Brew mode);
  private:
   Caffeine();
   static shared_ptr<Caffeine> singleton_;
@@ -36,6 +36,6 @@ class Caffeine {
   Brew mode_;
 };
 
-}
+}  // namespace caffeine
 
 #endif  // CAFFEINE_COMMON_HPP_
index 9c04875..7e969bb 100644 (file)
@@ -1,7 +1,7 @@
 #ifndef CAFFEINE_CORE_LAYERS_HPP_
 #define CAFFEINE_CORE_LAYERS_HPP_
 
-#include "caffeine/base.hpp"
+#include "caffeine/layer.hpp"
 
 namespace caffeine {
 
similarity index 70%
rename from src/caffeine/base.cpp
rename to src/caffeine/layer.cpp
index a925425..f832b54 100644 (file)
@@ -1,4 +1,4 @@
-#include "caffeine/base.hpp"
+#include "caffeine/layer.hpp"
 
 namespace caffeine {
 
@@ -6,7 +6,7 @@ namespace caffeine {
 // gpu specific implementations instead, and should not change these
 // functions.
 template <typename Dtype>
-inline void Layer<Dtype>::Forward(vector<const Blob<Dtype>*>& bottom,
+inline void Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   switch(Caffeine::mode()) {
   case Caffeine::CPU:
@@ -21,14 +21,15 @@ inline void Layer<Dtype>::Forward(vector<const Blob<Dtype>*>& bottom,
 };
 
 template <typename Dtype>
-inline void Layer<Dtype>::Backward(vector<Blob<Dtype>*>& bottom,
-    vector<const Blob<Dtype>*>* top, bool propagate_down) {
+inline Dtype Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
   switch(Caffeine::mode()) {
   case Caffeine::CPU:
-    Backward_cpu(bottom, top, propagate_down);
+    return Backward_cpu(top, propagate_down, bottom);
     break;
   case Caffeine::GPU:
-    Backward_gpu(bottom, top, propagate_down);
+    return Backward_gpu(top, propagate_down, bottom);
     break;
   default:
     LOG(FATAL) << "Unknown caffeine mode.";
@@ -36,7 +37,7 @@ inline void Layer<Dtype>::Backward(vector<Blob<Dtype>*>& bottom,
 };
 
 template <typename Dtype>
-inline void Layer<Dtype>::Predict(vector<const Blob<Dtype>*>& bottom,
+inline void Layer<Dtype>::Predict(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   switch(Caffeine::mode()) {
   case Caffeine::CPU:
similarity index 60%
rename from src/caffeine/base.hpp
rename to src/caffeine/layer.hpp
index 2409ac2..cc33caf 100644 (file)
@@ -1,5 +1,5 @@
-#ifndef CAFFEINE_BASE_H_
-#define CAFFEINE_BASE_H_
+#ifndef CAFFEINE_LAYER_H_
+#define CAFFEINE_LAYER_H_
 
 #include <vector>
 #include "caffeine/blob.hpp"
@@ -23,11 +23,12 @@ class Layer {
   // Forward, backward and predict wrappers. You should implement the cpu and
   // gpu specific implementations instead, and should not change these
   // functions.
-  inline void Forward(vector<const Blob<Dtype>*>& bottom,
+  inline void Forward(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  inline void Backward(vector<Blob<Dtype>*>& bottom,
-      vector<const Blob<Dtype>*>* top, bool propagate_down);
-  inline void Predict(vector<const Blob<Dtype>*>& bottom,
+  inline Dtype Backward(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down,
+      vector<Blob<Dtype>*>* bottom);
+  inline void Predict(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
 
  protected:
@@ -37,36 +38,40 @@ class Layer {
   vector<Blob<Dtype> > blobs;
 
   // Forward functions
-  virtual void Forward_cpu(vector<const Blob<Dtype>*>& bottom,
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) = 0;
   // If no gpu code is provided, we will simply use cpu code.
-  virtual void Forward_gpu(vector<const Blob<Dtype>*>& bottom,
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
     LOG(WARNING) << "Using CPU code as backup.";
     Forward_cpu(bottom, top);
   };
 
-  // Backward functions
-  virtual void Backward_cpu(vector<Blob<Dtype>*>& bottom,
-      vector<const Blob<Dtype>*>* top, bool propagate_down) = 0;
-  virtual void Backward_gpu(vector<Blob<Dtype>*>& bottom,
-      vector<const Blob<Dtype>*>* top, bool propagate_down) {
+  // Backward functions: the backward function will compute the gradients for
+  // any parameters and also for the bottom blobs if propagate_down is true.
+  // It will return the loss produced from this layer.
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down,
+      vector<Blob<Dtype>*>* bottom) = 0;
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down,
+      vector<Blob<Dtype>*>* bottom) {
     LOG(WARNING) << "Using CPU code as backup.";
-    Backward_cpu(bottom, top, propagate_down);
+    return Backward_cpu(top, propagate_down, bottom);
   };
 
   // Prediction functions: could be overridden, but the default behavior is to
   // simply call the forward functions.
-  virtual void Predict_cpu(vector<const Blob<Dtype>*>& bottom,
+  virtual void Predict_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) { Forward_cpu(bottom, top); };
   // For prediction, if there is no Predict_gpu, then there are two options:
   // to use predict_cpu as a backup, or to use forward_gpu (e.g. maybe the
   // author forgot to write what backup s/he wants?). Thus, we will require
   // the author to explicitly specify which fallback s/he wants.
-  virtual void Predict_gpu(vector<const Blob<Dtype>*>& bottom,
+  virtual void Predict_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) = 0;
 };  // class Layer
 
 }  // namespace caffeine
 
-#endif  // CAFFEINE_BASE_H_
+#endif  // CAFFEINE_LAYER_H_
index 20745b0..0b947e6 100644 (file)
@@ -1,4 +1,4 @@
-#include "caffeine/base.hpp"
+#include "caffeine/layer.hpp"
 
 namespace caffeine {
 
index 2bccc95..bcd23f3 100644 (file)
@@ -2,4 +2,13 @@ package caffeine;
 
 message LayerParameter {
   required string name = 1;
+}
+
+message BlobProto {
+  optional int32 num = 1 [default = 0];
+  optional int32 height = 2 [default = 0];
+  optional int32 width = 3 [default = 0];
+  optional int32 channels = 4 [default = 0];
+  repeated float data = 5;
+  repeated float diff = 6;
 }
\ No newline at end of file
index 6c45697..f0dc209 100644 (file)
@@ -37,6 +37,13 @@ TEST_F(SyncedMemoryTest, TestCPUWrite) {
   }
   const void* gpu_data = mem.gpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
+  // check if values are the same
+  char* recovered_value = new char[10];
+  cudaMemcpy((void*)recovered_value, gpu_data, 10, cudaMemcpyDeviceToHost);
+  for (int i = 0; i < mem.size(); ++i) {
+    EXPECT_EQ(((char*)recovered_value)[i], 1);
+  }
+  delete[] recovered_value;
 }
 
 TEST_F(SyncedMemoryTest, TestGPUWrite) {