rewrite ConvolutionLayer to use BaseConvolutionLayer helpers
authorJonathan L Long <jonlong@cs.berkeley.edu>
Mon, 22 Dec 2014 03:42:29 +0000 (19:42 -0800)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Sun, 11 Jan 2015 08:28:44 +0000 (00:28 -0800)
include/caffe/vision_layers.hpp
src/caffe/layers/conv_layer.cpp
src/caffe/layers/conv_layer.cu

index 4cf7e8b..4d93e6c 100644 (file)
@@ -124,7 +124,7 @@ class BaseConvolutionLayer : public Layer<Dtype> {
  *   the output channel N' columns of the output matrix.
  */
 template <typename Dtype>
-class ConvolutionLayer : public Layer<Dtype> {
+class ConvolutionLayer : public BaseConvolutionLayer<Dtype> {
  public:
   /**
    * @param param provides ConvolutionParameter convolution_param,
@@ -155,18 +155,10 @@ class ConvolutionLayer : public Layer<Dtype> {
    *    kernels + stream parallelism) engines.
    */
   explicit ConvolutionLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
-  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
-      const vector<Blob<Dtype>*>& top);
-  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
-      const vector<Blob<Dtype>*>& top);
-
+      : BaseConvolutionLayer<Dtype>(param) {}
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_CONVOLUTION;
   }
-  virtual inline int MinBottomBlobs() const { return 1; }
-  virtual inline int MinTopBlobs() const { return 1; }
-  virtual inline bool EqualNumBottomTopBlobs() const { return true; }
 
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
@@ -177,30 +169,10 @@ class ConvolutionLayer : public Layer<Dtype> {
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual inline bool reverse_dimensions() { return false; }
+  virtual void compute_output_shape();
 
-  int kernel_h_, kernel_w_;
-  int stride_h_, stride_w_;
-  int num_;
-  int channels_;
-  int pad_h_, pad_w_;
-  int height_, width_;
-  int group_;
-  int num_output_;
-  int height_out_, width_out_;
-  bool bias_term_;
-  bool is_1x1_;
 
-  /// M_ is the channel dimension of the output for a single group, which is the
-  /// leading dimension of the filter matrix.
-  int M_;
-  /// K_ is the dimension of an unrolled input for a single group, which is the
-  /// leading dimension of the data matrix.
-  int K_;
-  /// N_ is the spatial dimension of the output, the H x W, which are the last
-  /// dimensions of the data and filter matrices.
-  int N_;
-  Blob<Dtype> col_buffer_;
-  Blob<Dtype> bias_multiplier_;
 };
 
 #ifdef USE_CUDNN
index 0a03202..9fd2fc6 100644 (file)
 namespace caffe {
 
 template <typename Dtype>
-void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
-      const vector<Blob<Dtype>*>& top) {
-  // Configure the kernel size, padding, stride, and inputs.
-  ConvolutionParameter conv_param = this->layer_param_.convolution_param();
-  CHECK(!conv_param.has_kernel_size() !=
-      !(conv_param.has_kernel_h() && conv_param.has_kernel_w()))
-      << "Filter size is kernel_size OR kernel_h and kernel_w; not both";
-  CHECK(conv_param.has_kernel_size() ||
-      (conv_param.has_kernel_h() && conv_param.has_kernel_w()))
-      << "For non-square filters both kernel_h and kernel_w are required.";
-  CHECK((!conv_param.has_pad() && conv_param.has_pad_h()
-      && conv_param.has_pad_w())
-      || (!conv_param.has_pad_h() && !conv_param.has_pad_w()))
-      << "pad is pad OR pad_h and pad_w are required.";
-  CHECK((!conv_param.has_stride() && conv_param.has_stride_h()
-      && conv_param.has_stride_w())
-      || (!conv_param.has_stride_h() && !conv_param.has_stride_w()))
-      << "Stride is stride OR stride_h and stride_w are required.";
-  if (conv_param.has_kernel_size()) {
-    kernel_h_ = kernel_w_ = conv_param.kernel_size();
-  } else {
-    kernel_h_ = conv_param.kernel_h();
-    kernel_w_ = conv_param.kernel_w();
-  }
-  CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
-  CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
-  if (!conv_param.has_pad_h()) {
-    pad_h_ = pad_w_ = conv_param.pad();
-  } else {
-    pad_h_ = conv_param.pad_h();
-    pad_w_ = conv_param.pad_w();
-  }
-  if (!conv_param.has_stride_h()) {
-    stride_h_ = stride_w_ = conv_param.stride();
-  } else {
-    stride_h_ = conv_param.stride_h();
-    stride_w_ = conv_param.stride_w();
-  }
-  // Special case: im2col is the identity for 1x1 convolution with stride 1
-  // and no padding, so flag for skipping the buffer and transformation.
-  is_1x1_ = kernel_w_ == 1 && kernel_h_ == 1
-      && stride_h_ == 1 && stride_w_ == 1 && pad_h_ == 0 && pad_w_ == 0;
-  // Configure output channels and groups.
-  channels_ = bottom[0]->channels();
-  num_output_ = this->layer_param_.convolution_param().num_output();
-  CHECK_GT(num_output_, 0);
-  group_ = this->layer_param_.convolution_param().group();
-  CHECK_EQ(channels_ % group_, 0);
-  CHECK_EQ(num_output_ % group_, 0)
-      << "Number of output should be multiples of group.";
-  // Handle the parameters: weights and biases.
-  // - blobs_[0] holds the filter weights
-  // - blobs_[1] holds the biases (optional)
-  bias_term_ = this->layer_param_.convolution_param().bias_term();
-  if (this->blobs_.size() > 0) {
-    LOG(INFO) << "Skipping parameter initialization";
-  } else {
-    if (bias_term_) {
-      this->blobs_.resize(2);
-    } else {
-      this->blobs_.resize(1);
-    }
-    // Initialize and fill the weights:
-    // output channels x input channels per-group x kernel height x kernel width
-    this->blobs_[0].reset(new Blob<Dtype>(
-        num_output_, channels_ / group_, kernel_h_, kernel_w_));
-    shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
-        this->layer_param_.convolution_param().weight_filler()));
-    weight_filler->Fill(this->blobs_[0].get());
-    // If necessary, initialize and fill the biases:
-    // 1 x 1 x 1 x output channels
-    if (bias_term_) {
-      this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, num_output_));
-      shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
-          this->layer_param_.convolution_param().bias_filler()));
-      bias_filler->Fill(this->blobs_[1].get());
-    }
-  }
-  // Propagate gradients to the parameters (as directed by backward pass).
-  this->param_propagate_down_.resize(this->blobs_.size(), true);
-}
-
-template <typename Dtype>
-void ConvolutionLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
-      const vector<Blob<Dtype>*>& top) {
-  num_ = bottom[0]->num();
-  height_ = bottom[0]->height();
-  width_ = bottom[0]->width();
-  CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with"
-    " convolution kernel.";
-  // TODO: generalize to handle inputs of different shapes.
-  for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) {
-    CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num.";
-    CHECK_EQ(channels_, bottom[bottom_id]->channels())
-        << "Inputs must have same channels.";
-    CHECK_EQ(height_, bottom[bottom_id]->height())
-        << "Inputs must have same height.";
-    CHECK_EQ(width_, bottom[bottom_id]->width())
-        << "Inputs must have same width.";
-  }
-  // Shape the tops.
-  height_out_ =
-      (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
-  width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
-  for (int top_id = 0; top_id < top.size(); ++top_id) {
-    top[top_id]->Reshape(num_, num_output_, height_out_, width_out_);
-  }
-  // Prepare the matrix multiplication computation.
-  // Each input will be convolved as a single GEMM.
-  M_ = num_output_ / group_;
-  K_ = channels_ * kernel_h_ * kernel_w_ / group_;
-  N_ = height_out_ * width_out_;
-  // The im2col result buffer will only hold one image at a time to avoid
-  // overly large memory usage. In the special case of 1x1 convolution
-  // it goes lazily unused to save memory.
-  col_buffer_.Reshape(
-      1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_);
-  // Set up the all ones "bias multiplier" for adding biases by BLAS
-  if (bias_term_) {
-    bias_multiplier_.Reshape(1, 1, 1, N_);
-    caffe_set(N_, Dtype(1), bias_multiplier_.mutable_cpu_data());
-  }
+void ConvolutionLayer<Dtype>::compute_output_shape() {
+  this->height_out_ = (this->height_ + 2 * this->pad_h_ - this->kernel_h_)
+      / this->stride_h_ + 1;
+  this->width_out_ = (this->width_ + 2 * this->pad_w_ - this->kernel_w_)
+      / this->stride_w_ + 1;
 }
 
 template <typename Dtype>
 void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
+  const Dtype* weight = this->blobs_[0]->cpu_data();
   for (int i = 0; i < bottom.size(); ++i) {
     const Dtype* bottom_data = bottom[i]->cpu_data();
     Dtype* top_data = top[i]->mutable_cpu_data();
-    Dtype* col_buff = NULL;
-    if (!is_1x1_) {
-      col_buff = col_buffer_.mutable_cpu_data();
-    }
-    const Dtype* weight = this->blobs_[0]->cpu_data();
-    int weight_offset = M_ * K_;  // number of filter parameters in a group
-    int col_offset = K_ * N_;  // number of values in an input region / column
-    int top_offset = M_ * N_;  // number of values in an output region / column
-    for (int n = 0; n < num_; ++n) {
-      // im2col transformation: unroll input regions for filtering
-      // into column matrix for multplication.
-      if (!is_1x1_) {
-        im2col_cpu(bottom_data + bottom[i]->offset(n), channels_, height_,
-            width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
-            col_buff);
-      } else {  // special case for 1x1 convolution
-        col_buff = bottom[i]->mutable_cpu_data() + bottom[i]->offset(n);
-      }
-      // Take inner products for groups.
-      for (int g = 0; g < group_; ++g) {
-        caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
-          (Dtype)1., weight + weight_offset * g, col_buff + col_offset * g,
-          (Dtype)0., top_data + top[i]->offset(n) + top_offset * g);
-      }
-      // Add bias.
-      if (bias_term_) {
-        caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_,
-            N_, 1, (Dtype)1., this->blobs_[1]->cpu_data(),
-            bias_multiplier_.cpu_data(),
-            (Dtype)1., top_data + top[i]->offset(n));
+    for (int n = 0; n < this->num_; ++n) {
+      this->forward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight,
+          top_data + top[i]->offset(n));
+      if (this->bias_term_) {
+        const Dtype* bias = this->blobs_[1]->cpu_data();
+        this->forward_cpu_bias(top_data + top[i]->offset(n), bias);
       }
     }
   }
@@ -177,82 +37,37 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 template <typename Dtype>
 void ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
-  const Dtype* weight = NULL;
-  Dtype* weight_diff = NULL;
+  const Dtype* weight = this->blobs_[0]->cpu_data();
+  Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
   if (this->param_propagate_down_[0]) {
-    weight = this->blobs_[0]->cpu_data();
-    weight_diff = this->blobs_[0]->mutable_cpu_diff();
     caffe_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
   }
-  Dtype* bias_diff = NULL;
-  if (bias_term_ && this->param_propagate_down_[1]) {
-    bias_diff = this->blobs_[1]->mutable_cpu_diff();
-    caffe_set(this->blobs_[1]->count(), Dtype(0), bias_diff);
+  if (this->bias_term_ && this->param_propagate_down_[1]) {
+    caffe_set(this->blobs_[1]->count(), Dtype(0),
+        this->blobs_[1]->mutable_cpu_diff());
   }
-  const int weight_offset = M_ * K_;
-  const int col_offset = K_ * N_;
-  const int top_offset = M_ * N_;
   for (int i = 0; i < top.size(); ++i) {
-    const Dtype* top_diff = NULL;
+    const Dtype* top_diff = top[i]->cpu_diff();
+    const Dtype* bottom_data = bottom[i]->cpu_data();
+    Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
     // Bias gradient, if necessary.
-    if (bias_term_ && this->param_propagate_down_[1]) {
-      top_diff = top[i]->cpu_diff();
-      for (int n = 0; n < num_; ++n) {
-        caffe_cpu_gemv<Dtype>(CblasNoTrans, num_output_, N_,
-            1., top_diff + top[0]->offset(n),
-            bias_multiplier_.cpu_data(), 1.,
-            bias_diff);
+    if (this->bias_term_ && this->param_propagate_down_[1]) {
+      Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
+      for (int n = 0; n < this->num_; ++n) {
+        this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n));
       }
     }
     if (this->param_propagate_down_[0] || propagate_down[i]) {
-      if (!top_diff) {
-        top_diff = top[i]->cpu_diff();
-      }
-      Dtype* col_buff = NULL;
-      if (!is_1x1_) {
-        col_buff = col_buffer_.mutable_cpu_data();
-      }
-      const Dtype* bottom_data = bottom[i]->cpu_data();
-      Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
-      for (int n = 0; n < num_; ++n) {
-        // Since we saved memory in the forward pass by not storing all col
-        // data, we will need to recompute them.
-        if (!is_1x1_) {
-          im2col_cpu(bottom_data + bottom[i]->offset(n), channels_, height_,
-                    width_, kernel_h_, kernel_w_, pad_h_, pad_w_,
-                    stride_h_, stride_w_, col_buff);
-        } else {
-          col_buff = bottom[i]->mutable_cpu_data() + bottom[i]->offset(n);
-        }
+      for (int n = 0; n < this->num_; ++n) {
         // gradient w.r.t. weight. Note that we will accumulate diffs.
         if (this->param_propagate_down_[0]) {
-          for (int g = 0; g < group_; ++g) {
-            caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
-                (Dtype)1., top_diff + top[i]->offset(n) + top_offset * g,
-                col_buff + col_offset * g, (Dtype)1.,
-                weight_diff + weight_offset * g);
-          }
+          this->weight_cpu_gemm(bottom_data + bottom[i]->offset(n),
+              top_diff + top[i]->offset(n), weight_diff);
         }
         // gradient w.r.t. bottom data, if necessary.
         if (propagate_down[i]) {
-          if (weight == NULL) {
-            weight = this->blobs_[0]->cpu_data();
-          }
-          if (is_1x1_) {
-            col_buff = bottom[i]->mutable_cpu_diff() + bottom[i]->offset(n);
-          }
-          for (int g = 0; g < group_; ++g) {
-            caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
-                (Dtype)1., weight + weight_offset * g,
-                top_diff + top[i]->offset(n) + top_offset * g,
-                (Dtype)0., col_buff + col_offset * g);
-          }
-          // col2im back to the data
-          if (!is_1x1_) {
-            col2im_cpu(col_buff, channels_, height_, width_,
-                kernel_h_, kernel_w_, pad_h_, pad_w_,
-                stride_h_, stride_w_, bottom_diff + bottom[i]->offset(n));
-          }
+          this->backward_cpu_gemm(top_diff + top[i]->offset(n), weight,
+              bottom_diff + bottom[i]->offset(n));
         }
       }
     }
index af14fac..3902fdf 100644 (file)
 
 namespace caffe {
 
-/// @brief refer to CPU forward -- the BLAS implementation is the same.
 template <typename Dtype>
 void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
+  const Dtype* weight = this->blobs_[0]->gpu_data();
   for (int i = 0; i < bottom.size(); ++i) {
     const Dtype* bottom_data = bottom[i]->gpu_data();
     Dtype* top_data = top[i]->mutable_gpu_data();
-    Dtype* col_buff = NULL;
-    if (!is_1x1_) {
-      col_buff = col_buffer_.mutable_gpu_data();
-    }
-    const Dtype* weight = this->blobs_[0]->gpu_data();
-    int weight_offset = M_ * K_;
-    int col_offset = K_ * N_;
-    int top_offset = M_ * N_;
-    for (int n = 0; n < num_; ++n) {
-      // im2col transformation: unroll input regions for filtering
-      // into column matrix for multplication.
-      if (!is_1x1_) {
-        im2col_gpu(bottom_data + bottom[i]->offset(n), channels_, height_,
-            width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
-            col_buff);
-      } else {
-        col_buff = bottom[i]->mutable_gpu_data() + bottom[i]->offset(n);
-      }
-      // Take inner products for groups.
-      for (int g = 0; g < group_; ++g) {
-        caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
-          (Dtype)1., weight + weight_offset * g, col_buff + col_offset * g,
-          (Dtype)0., top_data + top[i]->offset(n) + top_offset * g);
-      }
-      // Add bias.
-      if (bias_term_) {
-        caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_,
-            N_, 1, (Dtype)1., this->blobs_[1]->gpu_data(),
-            bias_multiplier_.gpu_data(),
-            (Dtype)1., top_data + top[i]->offset(n));
+    for (int n = 0; n < this->num_; ++n) {
+      this->forward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight,
+          top_data + top[i]->offset(n));
+      if (this->bias_term_) {
+        const Dtype* bias = this->blobs_[1]->gpu_data();
+        this->forward_gpu_bias(top_data + top[i]->offset(n), bias);
       }
     }
   }
 }
 
-/// @brief refer to CPU backward -- the BLAS implementation is the same.
 template <typename Dtype>
 void ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
-  const Dtype* weight = NULL;
-  Dtype* weight_diff = NULL;
+  const Dtype* weight = this->blobs_[0]->gpu_data();
+  Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
   if (this->param_propagate_down_[0]) {
-    weight = this->blobs_[0]->gpu_data();
-    weight_diff = this->blobs_[0]->mutable_gpu_diff();
     caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
   }
-  Dtype* bias_diff = NULL;
-  if (bias_term_ && this->param_propagate_down_[1]) {
-    bias_diff = this->blobs_[1]->mutable_gpu_diff();
-    caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff);
+  if (this->bias_term_ && this->param_propagate_down_[1]) {
+    caffe_gpu_set(this->blobs_[1]->count(), Dtype(0),
+        this->blobs_[1]->mutable_gpu_diff());
   }
-  const int weight_offset = M_ * K_;
-  const int col_offset = K_ * N_;
-  const int top_offset = M_ * N_;
   for (int i = 0; i < top.size(); ++i) {
-    const Dtype* top_diff = NULL;
+    const Dtype* top_diff = top[i]->gpu_diff();
     // Bias gradient, if necessary.
-    if (bias_term_ && this->param_propagate_down_[1]) {
-      top_diff = top[i]->gpu_diff();
-      for (int n = 0; n < num_; ++n) {
-        caffe_gpu_gemv<Dtype>(CblasNoTrans, num_output_, N_,
-            1., top_diff + top[0]->offset(n),
-            bias_multiplier_.gpu_data(), 1.,
-            bias_diff);
+    if (this->bias_term_ && this->param_propagate_down_[1]) {
+      Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
+      for (int n = 0; n < this->num_; ++n) {
+        this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n));
       }
     }
     if (this->param_propagate_down_[0] || propagate_down[i]) {
-      if (!top_diff) {
-        top_diff = top[i]->gpu_diff();
-      }
-      Dtype* col_buff = NULL;
-      if (!is_1x1_) {
-        col_buff = col_buffer_.mutable_gpu_data();
-      }
       const Dtype* bottom_data = bottom[i]->gpu_data();
       Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
-      for (int n = 0; n < num_; ++n) {
-        // Since we saved memory in the forward pass by not storing all col
-        // data, we will need to recompute them.
-        if (!is_1x1_) {
-          im2col_gpu(bottom_data + bottom[i]->offset(n), channels_, height_,
-                    width_, kernel_h_, kernel_w_, pad_h_, pad_w_,
-                    stride_h_, stride_w_, col_buff);
-        } else {
-          col_buff = bottom[i]->mutable_gpu_data() + bottom[i]->offset(n);
-        }
+      for (int n = 0; n < this->num_; ++n) {
         // gradient w.r.t. weight. Note that we will accumulate diffs.
         if (this->param_propagate_down_[0]) {
-          for (int g = 0; g < group_; ++g) {
-            caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
-                (Dtype)1., top_diff + top[i]->offset(n) + top_offset * g,
-                col_buff + col_offset * g, (Dtype)1.,
-                weight_diff + weight_offset * g);
-          }
+          this->weight_gpu_gemm(bottom_data + bottom[i]->offset(n),
+              top_diff + top[i]->offset(n), weight_diff);
         }
-        // gradient w.r.t. bottom data, if necessary
+        // gradient w.r.t. bottom data, if necessary.
         if (propagate_down[i]) {
-          if (weight == NULL) {
-            weight = this->blobs_[0]->gpu_data();
-          }
-          if (is_1x1_) {
-            col_buff = bottom[i]->mutable_gpu_diff() + bottom[i]->offset(n);
-          }
-          for (int g = 0; g < group_; ++g) {
-            caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
-                (Dtype)1., weight + weight_offset * g,
-                top_diff + top[i]->offset(n) + top_offset * g,
-                (Dtype)0., col_buff + col_offset * g);
-          }
-          // col2im back to the data
-          if (!is_1x1_) {
-            col2im_gpu(col_buff, channels_, height_, width_,
-                kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
-                bottom_diff + bottom[i]->offset(n));
-          }
+          this->backward_gpu_gemm(top_diff + top[i]->offset(n), weight,
+              bottom_diff + bottom[i]->offset(n));
         }
       }
     }
   }
 }
 
-
 INSTANTIATE_LAYER_GPU_FUNCS(ConvolutionLayer);
 
 }  // namespace caffe