misc add
authorYangqing Jia <jiayq84@gmail.com>
Sat, 14 Sep 2013 05:14:58 +0000 (22:14 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Sat, 14 Sep 2013 05:14:58 +0000 (22:14 -0700)
src/caffeine/base.cpp [new file with mode: 0644]
src/caffeine/base.hpp [new file with mode: 0644]
src/caffeine/common.cpp [new file with mode: 0644]
src/caffeine/test_common.cpp [new file with mode: 0644]

diff --git a/src/caffeine/base.cpp b/src/caffeine/base.cpp
new file mode 100644 (file)
index 0000000..eefef73
--- /dev/null
@@ -0,0 +1,53 @@
+#include "caffeine/base.hpp"
+
+namespace caffeine {
+
+// Forward, backward and predict wrappers. You should implement the cpu and
+// gpu specific implementations instead, and should not change these
+// functions.
+// Dispatches Forward to the device-specific implementation selected by the
+// current Caffeine::mode(). Implement Forward_cpu / Forward_gpu in derived
+// layers instead of overriding this wrapper.
+template <typename Dtype>
+inline void Layer<Dtype>::Forward(vector<const Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  switch(Caffeine::mode()) {
+  case Caffeine::CPU:
+    Forward_cpu(bottom, top);
+    break;
+  case Caffeine::GPU:
+    Forward_gpu(bottom, top);
+    break;
+  default:
+    // Unknown brew mode: fail loudly rather than silently doing nothing.
+    CHECK(false);
+  }
+}
+
+// Dispatches Backward to the device-specific implementation selected by the
+// current Caffeine::mode(). propagate_down is forwarded untouched to the
+// device-specific function.
+template <typename Dtype>
+inline void Layer<Dtype>::Backward(vector<Blob<Dtype>*>& bottom,
+    vector<const Blob<Dtype>*>* top, bool propagate_down) {
+  switch(Caffeine::mode()) {
+  case Caffeine::CPU:
+    Backward_cpu(bottom, top, propagate_down);
+    break;
+  case Caffeine::GPU:
+    Backward_gpu(bottom, top, propagate_down);
+    break;
+  default:
+    // Unknown brew mode: fail loudly rather than silently doing nothing.
+    CHECK(false);
+  }
+}
+
+// Dispatches Predict to the device-specific implementation selected by the
+// current Caffeine::mode(). By default Predict_cpu falls through to
+// Forward_cpu (see base.hpp), while Predict_gpu must be provided explicitly.
+template <typename Dtype>
+inline void Layer<Dtype>::Predict(vector<const Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  switch(Caffeine::mode()) {
+  case Caffeine::CPU:
+    Predict_cpu(bottom, top);
+    break;
+  case Caffeine::GPU:
+    Predict_gpu(bottom, top);
+    break;
+  default:
+    // Unknown brew mode: fail loudly rather than silently doing nothing.
+    CHECK(false);
+  }
+}
+
+}  // namespace caffeine
diff --git a/src/caffeine/base.hpp b/src/caffeine/base.hpp
new file mode 100644 (file)
index 0000000..2409ac2
--- /dev/null
@@ -0,0 +1,72 @@
+#ifndef CAFFEINE_BASE_H_
+#define CAFFEINE_BASE_H_
+
+#include <vector>
+#include "caffeine/blob.hpp"
+#include "caffeine/common.hpp"
+#include "caffeine/proto/layer_param.pb.h"
+
+using std::vector;
+
+namespace caffeine {
+
+template <typename Dtype>
+class Layer {
+ public:
+  explicit Layer(const LayerParameter& param)
+    : layer_param_(param) {};
+  virtual ~Layer();
+  // SetUp: your function should implement this.
+  virtual void SetUp(vector<const Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) = 0;
+
+  // Forward, backward and predict wrappers. You should implement the cpu and
+  // gpu specific implementations instead, and should not change these
+  // functions.
+  inline void Forward(vector<const Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  inline void Backward(vector<Blob<Dtype>*>& bottom,
+      vector<const Blob<Dtype>*>* top, bool propagate_down);
+  inline void Predict(vector<const Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  // The protobuf that stores the layer parameters
+  LayerParameter layer_param_;
+  // The vector that stores the parameters as a set of blobs.
+  vector<Blob<Dtype> > blobs;
+
+  // Forward functions
+  virtual void Forward_cpu(vector<const Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) = 0;
+  // If no gpu code is provided, we will simply use cpu code.
+  virtual void Forward_gpu(vector<const Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+    LOG(WARNING) << "Using CPU code as backup.";
+    Forward_cpu(bottom, top);
+  };
+
+  // Backward functions
+  virtual void Backward_cpu(vector<Blob<Dtype>*>& bottom,
+      vector<const Blob<Dtype>*>* top, bool propagate_down) = 0;
+  virtual void Backward_gpu(vector<Blob<Dtype>*>& bottom,
+      vector<const Blob<Dtype>*>* top, bool propagate_down) {
+    LOG(WARNING) << "Using CPU code as backup.";
+    Backward_cpu(bottom, top, propagate_down);
+  };
+
+  // Prediction functions: could be overridden, but the default behavior is to
+  // simply call the forward functions.
+  virtual void Predict_cpu(vector<const Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) { Forward_cpu(bottom, top); };
+  // For prediction, if there is no Predict_gpu, then there are two options:
+  // to use predict_cpu as a backup, or to use forward_gpu (e.g. maybe the
+  // author forgot to write what backup s/he wants?). Thus, we will require
+  // the author to explicitly specify which fallback s/he wants.
+  virtual void Predict_gpu(vector<const Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) = 0;
+};  // class Layer
+
+}  // namespace caffeine
+
+#endif  // CAFFEINE_BASE_H_
diff --git a/src/caffeine/common.cpp b/src/caffeine/common.cpp
new file mode 100644 (file)
index 0000000..c4062b7
--- /dev/null
@@ -0,0 +1,35 @@
+#include "caffeine/common.hpp"
+
+namespace caffeine {
+
+// Definition of the process-wide singleton instance, lazily created by Get().
+shared_ptr<Caffeine> Caffeine::singleton_;
+
+// Constructs the singleton: defaults to CPU mode and eagerly creates the
+// cuBLAS handle (CUBLAS_CHECK aborts on failure).
+Caffeine::Caffeine()
+    : mode_(Caffeine::CPU) {
+  CUBLAS_CHECK(cublasCreate(&cublas_handle_));
+}
+
+Caffeine::~Caffeine() {
+  // Destroy the cuBLAS handle only if one was actually created. The original
+  // condition was inverted (!cublas_handle_), which destroyed the handle only
+  // when it was null and leaked it otherwise.
+  if (cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_));
+}
+
+// Returns the singleton, creating it on first use. NOTE(review): this lazy
+// initialization has no synchronization, so it is not safe if first called
+// concurrently from multiple threads.
+Caffeine& Caffeine::Get() {
+  if (!singleton_) {
+    singleton_.reset(new Caffeine());
+  }
+  return *singleton_;
+}
+
+// Accessor for the process-wide cuBLAS handle owned by the singleton.
+cublasHandle_t Caffeine::cublas_handle() {
+  return Get().cublas_handle_;
+}
+
+// Returns the current brew mode (CPU or GPU).
+Caffeine::Brew Caffeine::mode() {
+  return Get().mode_;
+}
+
+// Sets the brew mode and returns the mode now in effect. The original body
+// flowed off the end of a value-returning function, which is undefined
+// behavior if the caller uses the result.
+Caffeine::Brew Caffeine::set_mode(Caffeine::Brew mode) {
+  Get().mode_ = mode;
+  return Get().mode_;
+}
+
+}  // namespace caffeine
diff --git a/src/caffeine/test_common.cpp b/src/caffeine/test_common.cpp
new file mode 100644 (file)
index 0000000..a2fc544
--- /dev/null
@@ -0,0 +1,21 @@
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffeine/common.hpp"
+
+namespace caffeine {
+
+// Empty fixture: the tests below only exercise Caffeine's static interface.
+class CommonTest : public ::testing::Test {};
+
+// The singleton should hand out a non-null cuBLAS handle.
+TEST_F(CommonTest, TestCublasHandler) {
+  EXPECT_TRUE(Caffeine::cublas_handle());
+}
+
+// The default mode is CPU and set_mode must be observable via mode().
+TEST_F(CommonTest, TestBrewMode) {
+  EXPECT_EQ(Caffeine::mode(), Caffeine::CPU);
+  Caffeine::set_mode(Caffeine::GPU);
+  EXPECT_EQ(Caffeine::mode(), Caffeine::GPU);
+  // Restore the default: the mode lives on a process-wide singleton, so
+  // leaving it as GPU would make any later test order-dependent.
+  Caffeine::set_mode(Caffeine::CPU);
+}
+
+}