misc update
authorYangqing Jia <jiayq84@gmail.com>
Tue, 17 Sep 2013 18:25:50 +0000 (11:25 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Tue, 17 Sep 2013 18:25:50 +0000 (11:25 -0700)
12 files changed:
src/caffeine/blob.cpp
src/caffeine/blob.hpp
src/caffeine/common.hpp
src/caffeine/filler.hpp
src/caffeine/layers/dropout_layer.cu [moved from src/caffeine/dropout_layer.cu with 97% similarity]
src/caffeine/layers/inner_product_layer.cu [new file with mode: 0644]
src/caffeine/layers/neuron_layer.cpp [moved from src/caffeine/neuron_layer.cpp with 87% similarity]
src/caffeine/layers/relu_layer.cu [moved from src/caffeine/relu_layer.cu with 97% similarity]
src/caffeine/proto/layer_param.proto
src/caffeine/test/test_blob.cpp
src/caffeine/test/test_neuron_layer.cpp
src/caffeine/vision_layers.hpp

index 80d4acf..ab56551 100644 (file)
@@ -7,25 +7,25 @@
 namespace caffeine {
 
 template <typename Dtype>
-void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
-    const int width) {
+void Blob<Dtype>::Reshape(const int num, const int height,
+    const int width, const int channels) {
   CHECK_GT(num, 0);
-  CHECK_GT(channels, 0);
   CHECK_GT(height, 0);
   CHECK_GT(width, 0);
+  CHECK_GT(channels, 0);
   num_ = num;
-  channels_ = channels;
   height_ = height;
   width_ = width;
-  count_ = num_ * channels_ * height_ * width_;
+  channels_ = channels;
+  count_ = num_ * height_ * width_ * channels_;
   data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
   diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
 }
 
 template <typename Dtype>
-Blob<Dtype>::Blob(const int num, const int channels, const int height,
-    const int width) {
-  Reshape(num, channels, height, width);
+Blob<Dtype>::Blob(const int num, const int height,
+    const int width, const int channels) {
+  Reshape(num, height, width, channels);
 }
 
 template <typename Dtype>
@@ -84,7 +84,7 @@ void Blob<Dtype>::Update() {
 
 template <typename Dtype>
 void Blob<Dtype>::FromProto(const BlobProto& proto) {
-  Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
+  Reshape(proto.num(), proto.height(), proto.width(), proto.channels());
   // copy data
   Dtype* data_vec = mutable_cpu_data();
   for (int i = 0; i < count_; ++i) {
@@ -99,9 +99,9 @@ void Blob<Dtype>::FromProto(const BlobProto& proto) {
 template <typename Dtype>
 void Blob<Dtype>::ToProto(BlobProto* proto) {
   proto->set_num(num_);
-  proto->set_channels(channels_);
   proto->set_height(height_);
   proto->set_width(width_);
+  proto->set_channels(channels_);
   proto->clear_data();
   proto->clear_diff();
   const Dtype* data_vec = cpu_data();
index 4c0bf0d..e3aad86 100644 (file)
@@ -13,15 +13,15 @@ class Blob {
   Blob()
        : num_(0), channels_(0), height_(0), width_(0), count_(0), data_(),
        diff_() {};
-  explicit Blob(const int num, const int channels, const int height,
-    const int width);
+  explicit Blob(const int num, const int height,
+    const int width, const int channels);
   virtual ~Blob() {};
-  void Reshape(const int num, const int channels, const int height,
-      const int width);
+  void Reshape(const int num, const int height,
+      const int width, const int channels);
   inline int num() { return num_; }
-  inline int channels() { return channels_; }
   inline int height() { return height_; }
   inline int width() { return width_; }
+  inline int channels() { return channels_; }
   inline int count() {return count_; }
   
   const Dtype* cpu_data();
@@ -39,9 +39,9 @@ class Blob {
   shared_ptr<SyncedMemory> data_;
   shared_ptr<SyncedMemory> diff_;
   int num_;
-  int channels_;
   int height_;
   int width_;
+  int channels_;
   int count_;
 };  // class Blob
 
index 080cb9a..cc10434 100644 (file)
 #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
 #define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
 
+#define INSTANTIATE_CLASS(classname) \
+  template class classname<float>; \
+  template class classname<double>
+
 namespace caffeine {
 
 // We will use the boost shared_ptr instead of the new C++11 one mainly
index 880e615..07f31da 100644 (file)
@@ -25,6 +25,11 @@ class Filler {
 };  // class Filler
 
 template <typename Dtype>
+class FillerFactory {
+
+};
+
+template <typename Dtype>
 class ConstantFiller : public Filler<Dtype> {
  public:
   ConstantFiller(const FillerParameter& param) : Filler<Dtype>(param) {};
@@ -90,6 +95,24 @@ class GaussianFiller : public Filler<Dtype> {
   };
 };
 
+// A function to get a specific filler from the specification given in
+// FillerParameter. Ideally this would be replaced by a factory pattern,
+// but we will leave it this way for now.
+template <typename Dtype>
+Filler<Dtype>* GetFiller(const FillerParameter& param) {
+  const std::string& type = param.type();
+  if (type == "constant") {
+    return new ConstantFiller<Dtype>(param);
+  } else if (type == "uniform") {
+    return new UniformFiller<Dtype>(param);
+  } else if (type == "gaussian") {
+    return new GaussianFiller<Dtype>(param);
+  } else {
+    CHECK(false) << "Unknown filler name: " << param.type();
+  }
+  return (Filler<Dtype>*)(NULL);
+}
+
 }  // namespace caffeine
 
 #endif  // CAFFEINE_FILLER_HPP_
similarity index 97%
rename from src/caffeine/dropout_layer.cu
rename to src/caffeine/layers/dropout_layer.cu
index bfed41d..8dea15f 100644 (file)
@@ -6,7 +6,6 @@
 #include "caffeine/syncedmem.hpp"
 #include "caffeine/vision_layers.hpp"
 
-
 using std::max;
 
 namespace caffeine {
@@ -77,7 +76,6 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const int count = bottom[0]->count();
   if (Caffeine::phase() == Caffeine::TRAIN) {
-    // Create random numbers
     CURAND_CHECK(curandGenerate(Caffeine::curand_generator(),
         (unsigned int*)(rand_vec_->mutable_gpu_data()), count));
     // set thresholds
@@ -117,8 +115,7 @@ Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   return Dtype(0);
 }
 
-template class DropoutLayer<float>;
-template class DropoutLayer<double>;
+INSTANTIATE_CLASS(DropoutLayer);
 
 
 }  // namespace caffeine
diff --git a/src/caffeine/layers/inner_product_layer.cu b/src/caffeine/layers/inner_product_layer.cu
new file mode 100644 (file)
index 0000000..fa40093
--- /dev/null
@@ -0,0 +1,156 @@
+#include <mkl.h>
+#include <cublas_v2.h>
+
+#include "caffeine/blob.hpp"
+#include "caffeine/common.hpp"
+#include "caffeine/filler.hpp"
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+
+namespace caffeine {
+
+template <typename Dtype>
+void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) << "IP Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "IP Layer takes a single blob as output.";
+       const int num_output = this->layer_param_.num_output();
+       const bool gemm_last_dim = this->layer_param_.gemm_last_dim();
+       biasterm_ = this->layer_param_.biasterm();
+       // Figure out the dimensions
+       if (gemm_last_dim) {
+               M_ = bottom[0]->count() / bottom[0]->channels();
+       K_ = bottom[0]->channels();
+       N_ = num_output;
+       (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->height(),
+                       bottom[0]->width(), num_output);
+       } else {
+               M_ = bottom[0]->num();
+               K_ = bottom[0]->count() / bottom[0]->num();
+               N_ = num_output;
+               (*top)[0]->Reshape(bottom[0]->num(), 1, 1, num_output);
+       }
+       if (biasterm_) {
+               this->blobs_.resize(2);
+       } else {
+               this->blobs_.resize(1);
+       }
+       // Intialize the weight
+       Blob<Dtype>& weight = this->blobs_[0];
+       weight.Reshape(1, 1, K_, N_);
+       // fill the weights
+       shared_ptr<Filler<Dtype> > weight_filler(
+                       GetFiller<Dtype>(this->layer_param_.weight_filler()));
+       weight_filler->Fill(&weight);
+       // If necessary, intiialize and fill the bias term
+       if (biasterm_) {
+               Blob<Dtype>& bias = this->blobs_[1];
+               bias.Reshape(1, 1, 1, N_);
+               shared_ptr<Filler<Dtype> > bias_filler(
+                               GetFiller<Dtype>(this->layer_param_.bias_filler()));
+               bias_filler->Fill(&bias);
+       }
+};
+
+template <typename Dtype>
+void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  const Dtype* weight = this->blobs_[0].cpu_data();
+  const Dtype* bias = NULL;
+  if (biasterm_) {
+       bias = this->blobs_[1].cpu_data();
+  }
+  switch(sizeof(Dtype)) {
+  case sizeof(float):
+       // matrix multiply
+       cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, K_,
+                       1., (const float*)bottom_data, K_, (const float*)weight, N_, 0.,
+                       (float*)top_data, N_);
+       if (bias) {
+               // add bias
+               for (int i = 0; i < M_; ++i) {
+                       cblas_saxpy(N_, 1., (const float*)bias, 1,
+                                       (float*)(top_data) + (N_ * i), 1);
+               }
+       }
+  case sizeof(double):
+    // matrix multiply
+       cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, K_,
+                       1., (const double*)bottom_data, K_, (const double*)weight, N_, 0.,
+                       (double*)top_data, N_);
+       if (bias) {
+               // add bias
+               for (int i = 0; i < M_; ++i) {
+                       cblas_daxpy(N_, 1., (const double*)bias, 1,
+                                       (double*)(top_data) + (N_ * i), 1);
+               }
+       }
+  default:
+       CHECK(false) << "Unknown data type.";
+  }
+}
+
+template <typename Dtype>
+Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  CHECK(false);
+  return Dtype(0);
+}
+
+template <typename Dtype>
+__global__ void BroadcastCopy(const int total, const int vec_len,
+       const Dtype* in_vec, Dtype* out_matrix) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < total) {
+       int v_index = index % vec_len;
+       out_matrix[index] = in_vec[v_index];
+  }
+}
+
+template <typename Dtype>
+void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  const Dtype* weight = this->blobs_[0].gpu_data();
+  const Dtype* bias = NULL;
+  Dtype alpha = 1., beta = 0.;
+  if (biasterm_) {
+       bias = this->blobs_[1].gpu_data();
+       beta = 1.;
+       const int count = (*top)[0]->count();
+       // we pre-copy the bias to the results, and then call gemm.
+       BroadcastCopy<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+                       count, N_, bias, top_data);
+  }
+  switch(sizeof(Dtype)) {
+  case sizeof(float):
+       // matrix multiply: since cublas uses Fortran major, we actually do
+       // C' = B' A'
+       CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
+                       CUBLAS_OP_N, N_, M_, K_, (float*)&alpha, (const float*)weight, N_,
+                       (const float*)bottom_data, K_, (float*)&beta, (float*)top_data, N_));
+  case sizeof(double):
+    // matrix multiply
+       CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
+                       CUBLAS_OP_N, N_, M_, K_, (double*)&alpha, (const double*)weight, N_,
+                       (const double*)bottom_data, K_, (double*)&beta, (double*)top_data, N_));
+  default:
+       CHECK(false) << "Unknown data type.";
+  }
+}
+
+template <typename Dtype>
+Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  CHECK(false);
+  return Dtype(0.);
+}
+
+INSTANTIATE_CLASS(InnerProductLayer);
+
+}  // namespace caffeine
similarity index 87%
rename from src/caffeine/neuron_layer.cpp
rename to src/caffeine/layers/neuron_layer.cpp
index 050c690..4cac434 100644 (file)
@@ -12,7 +12,6 @@ void NeuronLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       bottom[0]->height(), bottom[0]->width());
 };
 
-template class NeuronLayer<float>;
-template class NeuronLayer<double>;
+INSTANTIATE_CLASS(NeuronLayer);
 
 }  // namespace caffeine
similarity index 97%
rename from src/caffeine/relu_layer.cu
rename to src/caffeine/layers/relu_layer.cu
index fb95b04..12a9b6c 100644 (file)
@@ -75,8 +75,7 @@ Dtype ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   return Dtype(0);
 }
 
-template class ReLULayer<float>;
-template class ReLULayer<double>;
+INSTANTIATE_CLASS(ReLULayer);
 
 
 }  // namespace caffeine
index 7bb3708..58dbe93 100644 (file)
@@ -19,10 +19,14 @@ message LayerParameter {
 
   optional float alpha = 13 [default = 1.]; // for local response norm
   optional float beta = 14 [default = 0.75]; // for local response norm
+
+  // for innerproduct: if true, carry out inner product computation on the
+  // last dim only
+  optional bool gemm_last_dim = 15 [ default = false];
 }
 
 message FillerParameter {
-  required string type = 1;
+  required string type = 1 [default = 'constant'];
   optional float value = 2 [default = 0]; // the value in constant filler
   optional float min = 3 [default = 0]; // the min value in uniform filler
   optional float max = 4 [default = 1]; // the max value in uniform filler
index f5ebc63..ba56422 100644 (file)
@@ -25,9 +25,9 @@ TYPED_TEST(BlobSimpleTest, TestInitialization) {
   EXPECT_TRUE(this->blob_);
   EXPECT_TRUE(this->blob_preshaped_);
   EXPECT_EQ(this->blob_preshaped_->num(), 2);
-  EXPECT_EQ(this->blob_preshaped_->channels(), 3);
-  EXPECT_EQ(this->blob_preshaped_->height(), 4);
-  EXPECT_EQ(this->blob_preshaped_->width(), 5);
+  EXPECT_EQ(this->blob_preshaped_->height(), 3);
+  EXPECT_EQ(this->blob_preshaped_->width(), 4);
+  EXPECT_EQ(this->blob_preshaped_->channels(), 5);
   EXPECT_EQ(this->blob_preshaped_->count(), 120);
   EXPECT_EQ(this->blob_->num(), 0);
   EXPECT_EQ(this->blob_->channels(), 0);
@@ -46,9 +46,9 @@ TYPED_TEST(BlobSimpleTest, TestPointers) {
 TYPED_TEST(BlobSimpleTest, TestReshape) {
   this->blob_->Reshape(2, 3, 4, 5);
   EXPECT_EQ(this->blob_->num(), 2);
-  EXPECT_EQ(this->blob_->channels(), 3);
-  EXPECT_EQ(this->blob_->height(), 4);
-  EXPECT_EQ(this->blob_->width(), 5);
+  EXPECT_EQ(this->blob_->height(), 3);
+  EXPECT_EQ(this->blob_->width(), 4);
+  EXPECT_EQ(this->blob_->channels(), 5);
   EXPECT_EQ(this->blob_->count(), 120);
 }
 
index 92a50a5..d64a014 100644 (file)
@@ -14,7 +14,7 @@ class NeuronLayerTest : public ::testing::Test {
  protected:
   NeuronLayerTest()
       : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
-        blob_top_(new Blob<Dtype>(2, 3, 4, 5)) {
+        blob_top_(new Blob<Dtype>()) {
     // fill the values
     FillerParameter filler_param;
     GaussianFiller<Dtype> filler(filler_param);
@@ -36,6 +36,7 @@ TYPED_TEST(NeuronLayerTest, TestReLUCPU) {
   LayerParameter layer_param;
   Caffeine::set_mode(Caffeine::CPU);
   ReLULayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
   layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
   // Now, check values
   const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
@@ -50,6 +51,7 @@ TYPED_TEST(NeuronLayerTest, TestReLUGPU) {
   LayerParameter layer_param;
   Caffeine::set_mode(Caffeine::GPU);
   ReLULayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
   layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
   // Now, check values
   const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
@@ -60,4 +62,76 @@ TYPED_TEST(NeuronLayerTest, TestReLUGPU) {
   }
 }
 
+TYPED_TEST(NeuronLayerTest, TestDropoutCPU) {
+  LayerParameter layer_param;
+  Caffeine::set_mode(Caffeine::CPU);
+  Caffeine::set_phase(Caffeine::TRAIN);
+  DropoutLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  float scale = 1. / (1. - layer_param.dropout_ratio());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    if (top_data[i] != 0) {
+      EXPECT_EQ(top_data[i], bottom_data[i] * scale);
+    }
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestDropoutCPUTestPhase) {
+  LayerParameter layer_param;
+  Caffeine::set_mode(Caffeine::CPU);
+  Caffeine::set_phase(Caffeine::TEST);
+  DropoutLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  float scale = 1. / (1. - layer_param.dropout_ratio());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    if (top_data[i] != 0) {
+      EXPECT_EQ(top_data[i], bottom_data[i]);
+    }
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestDropoutGPU) {
+  LayerParameter layer_param;
+  Caffeine::set_mode(Caffeine::GPU);
+  Caffeine::set_phase(Caffeine::TRAIN);
+  DropoutLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  float scale = 1. / (1. - layer_param.dropout_ratio());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    if (top_data[i] != 0) {
+      EXPECT_EQ(top_data[i], bottom_data[i] * scale);
+    }
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestDropoutGPUTestPhase) {
+  LayerParameter layer_param;
+  Caffeine::set_mode(Caffeine::GPU);
+  Caffeine::set_phase(Caffeine::TEST);
+  DropoutLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  float scale = 1. / (1. - layer_param.dropout_ratio());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    if (top_data[i] != 0) {
+      EXPECT_EQ(top_data[i], bottom_data[i]);
+    }
+  }
+}
+
 }
index 432da7e..8cf361c 100644 (file)
@@ -5,6 +5,8 @@
 
 namespace caffeine {
 
+// The neuron layer is a specific type of layers that just works on single
+// celements.
 template <typename Dtype>
 class NeuronLayer : public Layer<Dtype> {
  public:
@@ -14,6 +16,7 @@ class NeuronLayer : public Layer<Dtype> {
       vector<Blob<Dtype>*>* top);
 };
 
+
 template <typename Dtype>
 class ReLULayer : public NeuronLayer<Dtype> {
  public:
@@ -31,6 +34,7 @@ class ReLULayer : public NeuronLayer<Dtype> {
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
+
 template <typename Dtype>
 class DropoutLayer : public NeuronLayer<Dtype> {
  public:
@@ -55,8 +59,28 @@ class DropoutLayer : public NeuronLayer<Dtype> {
 };
 
 
+template <typename Dtype>
+class InnerProductLayer : public Layer<Dtype> {
+ public:
+  explicit InnerProductLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {};
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
 
-
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  int M_;
+  int K_;
+  int N_;
+  bool biasterm_;
+};
 
 }  // namespace caffeine