Im2col and Convolution layers support N spatial axes
authorJeff Donahue <jeff.donahue@gmail.com>
Thu, 5 Mar 2015 03:30:17 +0000 (19:30 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Sat, 19 Sep 2015 00:53:19 +0000 (17:53 -0700)
19 files changed:
include/caffe/util/im2col.hpp
include/caffe/vision_layers.hpp
src/caffe/layers/base_conv_layer.cpp
src/caffe/layers/conv_layer.cpp
src/caffe/layers/conv_layer.cu
src/caffe/layers/cudnn_conv_layer.cpp
src/caffe/layers/cudnn_conv_layer.cu
src/caffe/layers/deconv_layer.cpp
src/caffe/layers/deconv_layer.cu
src/caffe/layers/im2col_layer.cpp
src/caffe/layers/im2col_layer.cu
src/caffe/proto/caffe.proto
src/caffe/test/test_convolution_layer.cpp
src/caffe/test/test_deconvolution_layer.cpp
src/caffe/test/test_im2col_kernel.cu
src/caffe/test/test_im2col_layer.cpp
src/caffe/util/im2col.cpp
src/caffe/util/im2col.cu
src/caffe/util/upgrade_proto.cpp

index 0051e2f..531fd29 100644 (file)
@@ -4,24 +4,48 @@
 namespace caffe {
 
 template <typename Dtype>
+void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_col);
+
+template <typename Dtype>
 void im2col_cpu(const Dtype* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
     const int stride_w, Dtype* data_col);
 
 template <typename Dtype>
+void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_im);
+
+template <typename Dtype>
 void col2im_cpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int patch_h, const int patch_w,
     const int pad_h, const int pad_w, const int stride_h,
     const int stride_w, Dtype* data_im);
 
 template <typename Dtype>
+void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes,
+    const int col_size, const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_col);
+
+template <typename Dtype>
 void im2col_gpu(const Dtype* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
     const int stride_w, Dtype* data_col);
 
 template <typename Dtype>
+void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes,
+    const int im_size, const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_im);
+
+template <typename Dtype>
 void col2im_gpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int patch_h, const int patch_w,
     const int pad_h, const int pad_w, const int stride_h,
index 211e3d9..eae6582 100644 (file)
@@ -64,46 +64,101 @@ class BaseConvolutionLayer : public Layer<Dtype> {
   // Compute height_out_ and width_out_ from other parameters.
   virtual void compute_output_shape() = 0;
 
-  int kernel_h_, kernel_w_;
-  int stride_h_, stride_w_;
+  /// @brief The spatial dimensions of a filter kernel.
+  Blob<int> kernel_shape_;
+  /// @brief The spatial dimensions of the stride.
+  Blob<int> stride_;
+  /// @brief The spatial dimensions of the padding.
+  Blob<int> pad_;
+  /// @brief The spatial dimensions of the convolution input.
+  Blob<int> conv_input_shape_;
+  /// @brief The spatial dimensions of the input.
+  Blob<int> input_shape_;
+  /// @brief The spatial dimensions of the col_buffer.
+  vector<int> col_buffer_shape_;
+  /// @brief The spatial dimensions of the output.
+  vector<int> output_shape_;
+
+  int num_spatial_axes_;
+  int bottom_dim_;
+  int top_dim_;
+
+  int channel_axis_;
   int num_;
   int channels_;
-  int pad_h_, pad_w_;
-  int height_, width_;
   int group_;
+  int out_spatial_dim_;
+  int weight_offset_;
   int num_output_;
-  int height_out_, width_out_;
   bool bias_term_;
   bool is_1x1_;
+  bool force_nd_im2col_;
 
  private:
   // wrap im2col/col2im so we don't have to remember the (long) argument lists
   inline void conv_im2col_cpu(const Dtype* data, Dtype* col_buff) {
-    im2col_cpu(data, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      im2col_cpu(data, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff);
+    } else {
+      im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(),
+          col_buffer_shape_.data(), kernel_shape_.cpu_data(),
+          pad_.cpu_data(), stride_.cpu_data(), col_buff);
+    }
   }
   inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) {
-    col2im_cpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      col2im_cpu(col_buff, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], data);
+    } else {
+      col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(),
+          col_buffer_shape_.data(), kernel_shape_.cpu_data(),
+          pad_.cpu_data(), stride_.cpu_data(), data);
+    }
   }
 #ifndef CPU_ONLY
   inline void conv_im2col_gpu(const Dtype* data, Dtype* col_buff) {
-    im2col_gpu(data, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      im2col_gpu(data, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff);
+    } else {
+      im2col_nd_gpu(data, num_spatial_axes_, num_kernels_im2col_,
+          conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
+          kernel_shape_.gpu_data(), pad_.gpu_data(),
+          stride_.gpu_data(), col_buff);
+    }
   }
   inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) {
-    col2im_gpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      col2im_gpu(col_buff, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], data);
+    } else {
+      col2im_nd_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_,
+          conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
+          kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
+          data);
+    }
   }
 #endif
 
+  int num_kernels_im2col_;
+  int num_kernels_col2im_;
   int conv_out_channels_;
   int conv_in_channels_;
   int conv_out_spatial_dim_;
-  int conv_in_height_;
-  int conv_in_width_;
   int kernel_dim_;
-  int weight_offset_;
   int col_offset_;
   int output_offset_;
 
@@ -250,7 +305,7 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
   cudnnTensorDescriptor_t    bias_desc_;
   cudnnFilterDescriptor_t      filter_desc_;
   vector<cudnnConvolutionDescriptor_t> conv_descs_;
-  int bottom_offset_, top_offset_, weight_offset_, bias_offset_;
+  int bottom_offset_, top_offset_, bias_offset_;
   size_t workspaceSizeInBytes;
   void *workspace;
 };
@@ -287,11 +342,22 @@ class Im2colLayer : public Layer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
-  int kernel_h_, kernel_w_;
-  int stride_h_, stride_w_;
+  /// @brief The spatial dimensions of a filter kernel.
+  Blob<int> kernel_shape_;
+  /// @brief The spatial dimensions of the stride.
+  Blob<int> stride_;
+  /// @brief The spatial dimensions of the padding.
+  Blob<int> pad_;
+
+  int num_spatial_axes_;
+  int bottom_dim_;
+  int top_dim_;
+
+  int channel_axis_;
+  int num_;
   int channels_;
-  int height_, width_;
-  int pad_h_, pad_w_;
+
+  bool force_nd_im2col_;
 };
 
 // Forward declare PoolingLayer and SplitLayer for use in LRNLayer.
index ccb3adc..a5b90a5 100644 (file)
@@ -1,3 +1,4 @@
+#include <algorithm>
 #include <vector>
 
 #include "caffe/filler.hpp"
@@ -11,50 +12,103 @@ namespace caffe {
 template <typename Dtype>
 void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-  CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, "
-      << "corresponding to (num, channels, height, width)";
   // Configure the kernel size, padding, stride, and inputs.
   ConvolutionParameter conv_param = this->layer_param_.convolution_param();
-  CHECK(!conv_param.has_kernel_size() !=
-      !(conv_param.has_kernel_h() && conv_param.has_kernel_w()))
-      << "Filter size is kernel_size OR kernel_h and kernel_w; not both";
-  CHECK(conv_param.has_kernel_size() ||
-      (conv_param.has_kernel_h() && conv_param.has_kernel_w()))
-      << "For non-square filters both kernel_h and kernel_w are required.";
-  CHECK((!conv_param.has_pad() && conv_param.has_pad_h()
-      && conv_param.has_pad_w())
-      || (!conv_param.has_pad_h() && !conv_param.has_pad_w()))
-      << "pad is pad OR pad_h and pad_w are required.";
-  CHECK((!conv_param.has_stride() && conv_param.has_stride_h()
-      && conv_param.has_stride_w())
-      || (!conv_param.has_stride_h() && !conv_param.has_stride_w()))
-      << "Stride is stride OR stride_h and stride_w are required.";
-  if (conv_param.has_kernel_size()) {
-    kernel_h_ = kernel_w_ = conv_param.kernel_size();
+  force_nd_im2col_ = conv_param.force_nd_im2col();
+  channel_axis_ = bottom[0]->CanonicalAxisIndex(conv_param.axis());
+  const int first_spatial_axis = channel_axis_ + 1;
+  const int num_axes = bottom[0]->num_axes();
+  num_spatial_axes_ = num_axes - first_spatial_axis;
+  CHECK_GE(num_spatial_axes_, 0);
+  // Setup input dimensions (input_shape_).
+  vector<int> bottom_dim_blob_shape(1, num_spatial_axes_ + 1);
+  input_shape_.Reshape(bottom_dim_blob_shape);
+  int* input_shape_data = input_shape_.mutable_cpu_data();
+  for (int i = 0; i < num_spatial_axes_ + 1; ++i) {
+    input_shape_data[i] = bottom[0]->shape(channel_axis_ + i);
+  }
+  vector<int> spatial_dim_blob_shape(1, std::max(num_spatial_axes_, 1));
+  // Setup filter kernel dimensions (kernel_shape_).
+  kernel_shape_.Reshape(spatial_dim_blob_shape);
+  int* kernel_shape_data = kernel_shape_.mutable_cpu_data();
+  if (conv_param.has_kernel_h() || conv_param.has_kernel_w()) {
+    CHECK_EQ(num_spatial_axes_, 2)
+        << "kernel_h & kernel_w can only be used for 2D convolution.";
+    CHECK_EQ(0, conv_param.kernel_size_size())
+        << "Either kernel_size or kernel_h/w should be specified; not both.";
+    kernel_shape_data[0] = conv_param.kernel_h();
+    kernel_shape_data[1] = conv_param.kernel_w();
   } else {
-    kernel_h_ = conv_param.kernel_h();
-    kernel_w_ = conv_param.kernel_w();
+    const int num_kernel_dims = conv_param.kernel_size_size();
+    CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_)
+        << "kernel_size must be specified once, or once per spatial dimension "
+        << "(kernel_size specified " << num_kernel_dims << " times; "
+        << num_spatial_axes_ << " spatial dims);";
+      for (int i = 0; i < num_spatial_axes_; ++i) {
+        kernel_shape_data[i] =
+            conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i);
+      }
+  }
+  for (int i = 0; i < num_spatial_axes_; ++i) {
+    CHECK_GT(kernel_shape_data[i], 0) << "Filter dimensions must be nonzero.";
   }
-  CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
-  CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
-  if (!conv_param.has_pad_h()) {
-    pad_h_ = pad_w_ = conv_param.pad();
+  // Setup stride dimensions (stride_).
+  stride_.Reshape(spatial_dim_blob_shape);
+  int* stride_data = stride_.mutable_cpu_data();
+  if (conv_param.has_stride_h() || conv_param.has_stride_w()) {
+    CHECK_EQ(num_spatial_axes_, 2)
+        << "stride_h & stride_w can only be used for 2D convolution.";
+    CHECK_EQ(0, conv_param.stride_size())
+        << "Either stride or stride_h/w should be specified; not both.";
+    stride_data[0] = conv_param.stride_h();
+    stride_data[1] = conv_param.stride_w();
   } else {
-    pad_h_ = conv_param.pad_h();
-    pad_w_ = conv_param.pad_w();
+    const int num_stride_dims = conv_param.stride_size();
+    CHECK(num_stride_dims == 0 || num_stride_dims == 1 ||
+          num_stride_dims == num_spatial_axes_)
+        << "stride must be specified once, or once per spatial dimension "
+        << "(stride specified " << num_stride_dims << " times; "
+        << num_spatial_axes_ << " spatial dims);";
+    const int kDefaultStride = 1;
+    for (int i = 0; i < num_spatial_axes_; ++i) {
+      stride_data[i] = (num_stride_dims == 0) ? kDefaultStride :
+          conv_param.stride((num_stride_dims == 1) ? 0 : i);
+      CHECK_GT(stride_data[i], 0) << "Stride dimensions must be nonzero.";
+    }
   }
-  if (!conv_param.has_stride_h()) {
-    stride_h_ = stride_w_ = conv_param.stride();
+  // Setup pad dimensions (pad_).
+  pad_.Reshape(spatial_dim_blob_shape);
+  int* pad_data = pad_.mutable_cpu_data();
+  if (conv_param.has_pad_h() || conv_param.has_pad_w()) {
+    CHECK_EQ(num_spatial_axes_, 2)
+        << "pad_h & pad_w can only be used for 2D convolution.";
+    CHECK_EQ(0, conv_param.pad_size())
+        << "Either pad or pad_h/w should be specified; not both.";
+    pad_data[0] = conv_param.pad_h();
+    pad_data[1] = conv_param.pad_w();
   } else {
-    stride_h_ = conv_param.stride_h();
-    stride_w_ = conv_param.stride_w();
+    const int num_pad_dims = conv_param.pad_size();
+    CHECK(num_pad_dims == 0 || num_pad_dims == 1 ||
+          num_pad_dims == num_spatial_axes_)
+        << "pad must be specified once, or once per spatial dimension "
+        << "(pad specified " << num_pad_dims << " times; "
+        << num_spatial_axes_ << " spatial dims);";
+    const int kDefaultPad = 0;
+    for (int i = 0; i < num_spatial_axes_; ++i) {
+      pad_data[i] = (num_pad_dims == 0) ? kDefaultPad :
+          conv_param.pad((num_pad_dims == 1) ? 0 : i);
+    }
   }
   // Special case: im2col is the identity for 1x1 convolution with stride 1
   // and no padding, so flag for skipping the buffer and transformation.
-  is_1x1_ = kernel_w_ == 1 && kernel_h_ == 1
-      && stride_h_ == 1 && stride_w_ == 1 && pad_h_ == 0 && pad_w_ == 0;
+  is_1x1_ = true;
+  for (int i = 0; i < num_spatial_axes_; ++i) {
+    is_1x1_ &=
+        kernel_shape_data[i] == 1 && stride_data[i] == 1 && pad_data[i] == 0;
+    if (!is_1x1_) { break; }
+  }
   // Configure output channels and groups.
-  channels_ = bottom[0]->channels();
+  channels_ = bottom[0]->shape(channel_axis_);
   num_output_ = this->layer_param_.convolution_param().num_output();
   CHECK_GT(num_output_, 0);
   group_ = this->layer_param_.convolution_param().group();
@@ -71,8 +125,29 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   // Handle the parameters: weights and biases.
   // - blobs_[0] holds the filter weights
   // - blobs_[1] holds the biases (optional)
+  vector<int> weight_shape(2);
+  weight_shape[0] = conv_out_channels_;
+  weight_shape[1] = conv_in_channels_ / group_;
+  for (int i = 0; i < num_spatial_axes_; ++i) {
+    weight_shape.push_back(kernel_shape_data[i]);
+  }
   bias_term_ = this->layer_param_.convolution_param().bias_term();
+  vector<int> bias_shape(bias_term_, num_output_);
   if (this->blobs_.size() > 0) {
+    CHECK_EQ(1 + bias_term_, this->blobs_.size())
+        << "Incorrect number of weight blobs.";
+    if (weight_shape != this->blobs_[0]->shape()) {
+      Blob<Dtype> weight_shaped_blob(weight_shape);
+      LOG(FATAL) << "Incorrect weight shape: expected shape "
+          << weight_shaped_blob.shape_string() << "; instead, shape was "
+          << this->blobs_[0]->shape_string();
+    }
+    if (bias_term_ && bias_shape != this->blobs_[1]->shape()) {
+      Blob<Dtype> bias_shaped_blob(bias_shape);
+      LOG(FATAL) << "Incorrect bias shape: expected shape "
+          << bias_shaped_blob.shape_string() << "; instead, shape was "
+          << this->blobs_[1]->shape_string();
+    }
     LOG(INFO) << "Skipping parameter initialization";
   } else {
     if (bias_term_) {
@@ -82,20 +157,20 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     }
     // Initialize and fill the weights:
     // output channels x input channels per-group x kernel height x kernel width
-    this->blobs_[0].reset(new Blob<Dtype>(
-        conv_out_channels_, conv_in_channels_ / group_, kernel_h_, kernel_w_));
+    this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
     shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
         this->layer_param_.convolution_param().weight_filler()));
     weight_filler->Fill(this->blobs_[0].get());
     // If necessary, initialize and fill the biases.
     if (bias_term_) {
-      vector<int> bias_shape(1, num_output_);
       this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
       shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
           this->layer_param_.convolution_param().bias_filler()));
       bias_filler->Fill(this->blobs_[1].get());
     }
   }
+  kernel_dim_ = this->blobs_[0]->count(1);
+  weight_offset_ = conv_out_channels_ * kernel_dim_ / group_;
   // Propagate gradients to the parameters (as directed by backward pass).
   this->param_propagate_down_.resize(this->blobs_.size(), true);
 }
@@ -103,52 +178,68 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
 template <typename Dtype>
 void BaseConvolutionLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-  CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, "
-      << "corresponding to (num, channels, height, width)";
-  num_ = bottom[0]->num();
-  height_ = bottom[0]->height();
-  width_ = bottom[0]->width();
-  CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with"
-    " convolution kernel.";
+  const int first_spatial_axis = channel_axis_ + 1;
+  CHECK_EQ(bottom[0]->num_axes(), first_spatial_axis + num_spatial_axes_)
+      << "bottom num_axes may not change.";
+  num_ = bottom[0]->count(0, channel_axis_);
+  CHECK_EQ(bottom[0]->shape(channel_axis_), channels_)
+      << "Input size incompatible with convolution kernel.";
   // TODO: generalize to handle inputs of different shapes.
   for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) {
-    CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num.";
-    CHECK_EQ(channels_, bottom[bottom_id]->channels())
-        << "Inputs must have same channels.";
-    CHECK_EQ(height_, bottom[bottom_id]->height())
-        << "Inputs must have same height.";
-    CHECK_EQ(width_, bottom[bottom_id]->width())
-        << "Inputs must have same width.";
+    CHECK(bottom[0]->shape() == bottom[bottom_id]->shape())
+        << "All inputs must have the same shape.";
   }
   // Shape the tops.
   compute_output_shape();
+  vector<int> top_shape(bottom[0]->shape().begin(),
+      bottom[0]->shape().begin() + channel_axis_);
+  top_shape.push_back(num_output_);
+  for (int i = 0; i < num_spatial_axes_; ++i) {
+    top_shape.push_back(output_shape_[i]);
+  }
   for (int top_id = 0; top_id < top.size(); ++top_id) {
-    top[top_id]->Reshape(num_, num_output_, height_out_, width_out_);
+    top[top_id]->Reshape(top_shape);
   }
   if (reverse_dimensions()) {
-    conv_in_height_ = height_out_;
-    conv_in_width_ = width_out_;
-    conv_out_spatial_dim_ = height_ * width_;
+    conv_out_spatial_dim_ = bottom[0]->count(first_spatial_axis);
   } else {
-    conv_in_height_ = height_;
-    conv_in_width_ = width_;
-    conv_out_spatial_dim_ = height_out_ * width_out_;
+    conv_out_spatial_dim_ = top[0]->count(first_spatial_axis);
   }
-  kernel_dim_ = conv_in_channels_ * kernel_h_ * kernel_w_;
-  weight_offset_ = conv_out_channels_ * kernel_dim_ / group_ / group_;
-  col_offset_ = kernel_dim_ * conv_out_spatial_dim_ / group_;
+  col_offset_ = kernel_dim_ * conv_out_spatial_dim_;
   output_offset_ = conv_out_channels_ * conv_out_spatial_dim_ / group_;
+  // Setup input dimensions (conv_input_shape_).
+  vector<int> bottom_dim_blob_shape(1, num_spatial_axes_ + 1);
+  conv_input_shape_.Reshape(bottom_dim_blob_shape);
+  int* conv_input_shape_data = conv_input_shape_.mutable_cpu_data();
+  for (int i = 0; i < num_spatial_axes_ + 1; ++i) {
+    if (reverse_dimensions()) {
+      conv_input_shape_data[i] = top[0]->shape(channel_axis_ + i);
+    } else {
+      conv_input_shape_data[i] = bottom[0]->shape(channel_axis_ + i);
+    }
+  }
   // The im2col result buffer will only hold one image at a time to avoid
   // overly large memory usage. In the special case of 1x1 convolution
   // it goes lazily unused to save memory.
-  if (reverse_dimensions()) {
-    col_buffer_.Reshape(1, kernel_dim_, height_, width_);
-  } else {
-    col_buffer_.Reshape(1, kernel_dim_, height_out_, width_out_);
+  col_buffer_shape_.clear();
+  col_buffer_shape_.push_back(kernel_dim_ * group_);
+  const int* input_shape_data = input_shape_.cpu_data() + 1;
+  for (int i = 0; i < num_spatial_axes_; ++i) {
+    if (reverse_dimensions()) {
+      col_buffer_shape_.push_back(input_shape_data[i]);
+    } else {
+      col_buffer_shape_.push_back(output_shape_[i]);
+    }
   }
+  col_buffer_.Reshape(col_buffer_shape_);
+  bottom_dim_ = bottom[0]->count(channel_axis_);
+  top_dim_ = top[0]->count(channel_axis_);
+  num_kernels_im2col_ = conv_in_channels_ * conv_out_spatial_dim_;
+  num_kernels_col2im_ = reverse_dimensions() ? top_dim_ : bottom_dim_;
   // Set up the all ones "bias multiplier" for adding biases by BLAS
+  out_spatial_dim_ = top[0]->count(first_spatial_axis);
   if (bias_term_) {
-    vector<int> bias_multiplier_shape(1, height_out_ * width_out_);
+    vector<int> bias_multiplier_shape(1, out_spatial_dim_);
     bias_multiplier_.Reshape(bias_multiplier_shape);
     caffe_set(bias_multiplier_.count(), Dtype(1),
         bias_multiplier_.mutable_cpu_data());
@@ -167,7 +258,7 @@ void BaseConvolutionLayer<Dtype>::forward_cpu_gemm(const Dtype* input,
   }
   for (int g = 0; g < group_; ++g) {
     caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, conv_out_channels_ /
-        group_, conv_out_spatial_dim_, kernel_dim_ / group_,
+        group_, conv_out_spatial_dim_, kernel_dim_,
         (Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g,
         (Dtype)0., output + output_offset_ * g);
   }
@@ -177,7 +268,7 @@ template <typename Dtype>
 void BaseConvolutionLayer<Dtype>::forward_cpu_bias(Dtype* output,
     const Dtype* bias) {
   caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_,
-      height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(),
+      out_spatial_dim_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(),
       (Dtype)1., output);
 }
 
@@ -189,7 +280,7 @@ void BaseConvolutionLayer<Dtype>::backward_cpu_gemm(const Dtype* output,
     col_buff = input;
   }
   for (int g = 0; g < group_; ++g) {
-    caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, kernel_dim_ / group_,
+    caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, kernel_dim_,
         conv_out_spatial_dim_, conv_out_channels_ / group_,
         (Dtype)1., weights + weight_offset_ * g, output + output_offset_ * g,
         (Dtype)0., col_buff + col_offset_ * g);
@@ -209,7 +300,7 @@ void BaseConvolutionLayer<Dtype>::weight_cpu_gemm(const Dtype* input,
   }
   for (int g = 0; g < group_; ++g) {
     caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, conv_out_channels_ / group_,
-        kernel_dim_ / group_, conv_out_spatial_dim_,
+        kernel_dim_, conv_out_spatial_dim_,
         (Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g,
         (Dtype)1., weights + weight_offset_ * g);
   }
@@ -218,7 +309,7 @@ void BaseConvolutionLayer<Dtype>::weight_cpu_gemm(const Dtype* input,
 template <typename Dtype>
 void BaseConvolutionLayer<Dtype>::backward_cpu_bias(Dtype* bias,
     const Dtype* input) {
-  caffe_cpu_gemv<Dtype>(CblasNoTrans, num_output_, height_out_ * width_out_, 1.,
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, num_output_, out_spatial_dim_, 1.,
       input, bias_multiplier_.cpu_data(), 1., bias);
 }
 
@@ -236,7 +327,7 @@ void BaseConvolutionLayer<Dtype>::forward_gpu_gemm(const Dtype* input,
   }
   for (int g = 0; g < group_; ++g) {
     caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, conv_out_channels_ /
-        group_, conv_out_spatial_dim_, kernel_dim_ / group_,
+        group_, conv_out_spatial_dim_, kernel_dim_,
         (Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g,
         (Dtype)0., output + output_offset_ * g);
   }
@@ -246,7 +337,7 @@ template <typename Dtype>
 void BaseConvolutionLayer<Dtype>::forward_gpu_bias(Dtype* output,
     const Dtype* bias) {
   caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_,
-      height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.gpu_data(),
+      out_spatial_dim_, 1, (Dtype)1., bias, bias_multiplier_.gpu_data(),
       (Dtype)1., output);
 }
 
@@ -258,7 +349,7 @@ void BaseConvolutionLayer<Dtype>::backward_gpu_gemm(const Dtype* output,
     col_buff = input;
   }
   for (int g = 0; g < group_; ++g) {
-    caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, kernel_dim_ / group_,
+    caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, kernel_dim_,
         conv_out_spatial_dim_, conv_out_channels_ / group_,
         (Dtype)1., weights + weight_offset_ * g, output + output_offset_ * g,
         (Dtype)0., col_buff + col_offset_ * g);
@@ -278,7 +369,7 @@ void BaseConvolutionLayer<Dtype>::weight_gpu_gemm(const Dtype* input,
   }
   for (int g = 0; g < group_; ++g) {
     caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, conv_out_channels_ / group_,
-        kernel_dim_ / group_, conv_out_spatial_dim_,
+        kernel_dim_, conv_out_spatial_dim_,
         (Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g,
         (Dtype)1., weights + weight_offset_ * g);
   }
@@ -287,7 +378,7 @@ void BaseConvolutionLayer<Dtype>::weight_gpu_gemm(const Dtype* input,
 template <typename Dtype>
 void BaseConvolutionLayer<Dtype>::backward_gpu_bias(Dtype* bias,
     const Dtype* input) {
-  caffe_gpu_gemv<Dtype>(CblasNoTrans, num_output_, height_out_ * width_out_, 1.,
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, num_output_, out_spatial_dim_, 1.,
       input, bias_multiplier_.gpu_data(), 1., bias);
 }
 
index 928ef5e..5cf2697 100644 (file)
@@ -10,10 +10,18 @@ namespace caffe {
 
 template <typename Dtype>
 void ConvolutionLayer<Dtype>::compute_output_shape() {
-  this->height_out_ = (this->height_ + 2 * this->pad_h_ - this->kernel_h_)
-      / this->stride_h_ + 1;
-  this->width_out_ = (this->width_ + 2 * this->pad_w_ - this->kernel_w_)
-      / this->stride_w_ + 1;
+  // input_shape_ + 1 to skip channel axis
+  const int* input_shape_data = this->input_shape_.cpu_data() + 1;
+  const int* kernel_shape_data = this->kernel_shape_.cpu_data();
+  const int* stride_data = this->stride_.cpu_data();
+  const int* pad_data = this->pad_.cpu_data();
+  this->output_shape_.clear();
+  for (int i = 0; i < this->num_spatial_axes_; ++i) {
+    const int input_dim = input_shape_data[i];
+    const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i])
+        / stride_data[i] + 1;
+    this->output_shape_.push_back(output_dim);
+  }
 }
 
 template <typename Dtype>
@@ -24,11 +32,11 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     const Dtype* bottom_data = bottom[i]->cpu_data();
     Dtype* top_data = top[i]->mutable_cpu_data();
     for (int n = 0; n < this->num_; ++n) {
-      this->forward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight,
-          top_data + top[i]->offset(n));
+      this->forward_cpu_gemm(bottom_data + n * this->bottom_dim_, weight,
+          top_data + n * this->top_dim_);
       if (this->bias_term_) {
         const Dtype* bias = this->blobs_[1]->cpu_data();
-        this->forward_cpu_bias(top_data + top[i]->offset(n), bias);
+        this->forward_cpu_bias(top_data + n * this->top_dim_, bias);
       }
     }
   }
@@ -47,20 +55,20 @@ void ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     if (this->bias_term_ && this->param_propagate_down_[1]) {
       Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
       for (int n = 0; n < this->num_; ++n) {
-        this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n));
+        this->backward_cpu_bias(bias_diff, top_diff + n * this->top_dim_);
       }
     }
     if (this->param_propagate_down_[0] || propagate_down[i]) {
       for (int n = 0; n < this->num_; ++n) {
         // gradient w.r.t. weight. Note that we will accumulate diffs.
         if (this->param_propagate_down_[0]) {
-          this->weight_cpu_gemm(bottom_data + bottom[i]->offset(n),
-              top_diff + top[i]->offset(n), weight_diff);
+          this->weight_cpu_gemm(bottom_data + n * this->bottom_dim_,
+              top_diff + n * this->top_dim_, weight_diff);
         }
         // gradient w.r.t. bottom data, if necessary.
         if (propagate_down[i]) {
-          this->backward_cpu_gemm(top_diff + top[i]->offset(n), weight,
-              bottom_diff + bottom[i]->offset(n));
+          this->backward_cpu_gemm(top_diff + n * this->top_dim_, weight,
+              bottom_diff + n * this->bottom_dim_);
         }
       }
     }
index b8a98ff..b429d2b 100644 (file)
@@ -16,11 +16,11 @@ void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     const Dtype* bottom_data = bottom[i]->gpu_data();
     Dtype* top_data = top[i]->mutable_gpu_data();
     for (int n = 0; n < this->num_; ++n) {
-      this->forward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight,
-          top_data + top[i]->offset(n));
+      this->forward_gpu_gemm(bottom_data + n * this->bottom_dim_, weight,
+          top_data + n * this->top_dim_);
       if (this->bias_term_) {
         const Dtype* bias = this->blobs_[1]->gpu_data();
-        this->forward_gpu_bias(top_data + top[i]->offset(n), bias);
+        this->forward_gpu_bias(top_data + n * this->top_dim_, bias);
       }
     }
   }
@@ -37,7 +37,7 @@ void ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     if (this->bias_term_ && this->param_propagate_down_[1]) {
       Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
       for (int n = 0; n < this->num_; ++n) {
-        this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n));
+        this->backward_gpu_bias(bias_diff, top_diff + n * this->top_dim_);
       }
     }
     if (this->param_propagate_down_[0] || propagate_down[i]) {
@@ -46,13 +46,13 @@ void ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
       for (int n = 0; n < this->num_; ++n) {
         // gradient w.r.t. weight. Note that we will accumulate diffs.
         if (this->param_propagate_down_[0]) {
-          this->weight_gpu_gemm(bottom_data + bottom[i]->offset(n),
-              top_diff + top[i]->offset(n), weight_diff);
+          this->weight_gpu_gemm(bottom_data + n * this->bottom_dim_,
+              top_diff + n * this->top_dim_, weight_diff);
         }
         // gradient w.r.t. bottom data, if necessary.
         if (propagate_down[i]) {
-          this->backward_gpu_gemm(top_diff + top[i]->offset(n), weight,
-              bottom_diff + bottom[i]->offset(n));
+          this->backward_gpu_gemm(top_diff + n * this->top_dim_, weight,
+              bottom_diff + n * this->bottom_dim_);
         }
       }
     }
index 104d2b9..3514fe2 100644 (file)
@@ -34,14 +34,15 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
   }
 
   // Set the indexing parameters.
-  weight_offset_ = (this->num_output_ / this->group_)
-      * (this->channels_ / this->group_) * this->kernel_h_ * this->kernel_w_;
   bias_offset_ = (this->num_output_ / this->group_);
 
   // Create filter descriptor.
+  const int* kernel_shape_data = this->kernel_shape_.cpu_data();
+  const int kernel_h = kernel_shape_data[0];
+  const int kernel_w = kernel_shape_data[1];
   cudnn::createFilterDesc<Dtype>(&filter_desc_,
       this->num_output_ / this->group_, this->channels_ / this->group_,
-      this->kernel_h_, this->kernel_w_);
+      kernel_h, kernel_w);
 
   // Create tensor descriptor(s) for data and corresponding convolution(s).
   for (int i = 0; i < bottom.size(); i++) {
@@ -68,29 +69,36 @@ template <typename Dtype>
 void CuDNNConvolutionLayer<Dtype>::Reshape(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   ConvolutionLayer<Dtype>::Reshape(bottom, top);
-  bottom_offset_ = (this->channels_ / this->group_)
-      * this->height_ * this->width_;
-  top_offset_ = (this->num_output_ / this->group_)
-      * this->height_out_ * this->width_out_;
+  CHECK_EQ(2, this->num_spatial_axes_)
+      << "CuDNNConvolution input must have 2 spatial axes "
+      << "(e.g., height and width). "
+      << "Use 'engine: CAFFE' for general ND convolution.";
+  bottom_offset_ = this->bottom_dim_ / this->group_;
+  top_offset_ = this->top_dim_ / this->group_;
+  const int height = bottom[0]->shape(this->channel_axis_ + 1);
+  const int width = bottom[0]->shape(this->channel_axis_ + 2);
+  const int height_out = top[0]->shape(this->channel_axis_ + 1);
+  const int width_out = top[0]->shape(this->channel_axis_ + 2);
+  const int* pad_data = this->pad_.cpu_data();
+  const int pad_h = pad_data[0];
+  const int pad_w = pad_data[1];
+  const int* stride_data = this->stride_.cpu_data();
+  const int stride_h = stride_data[0];
+  const int stride_w = stride_data[1];
 
   for (int i = 0; i < bottom.size(); i++) {
     cudnn::setTensor4dDesc<Dtype>(&bottom_descs_[i],
         this->num_,
-        this->channels_ / this->group_,
-        this->height_, this->width_,
-        this->channels_ * this->height_ * this->width_,
-        this->height_ * this->width_,
-        this->width_, 1);
+        this->channels_ / this->group_, height, width,
+        this->channels_ * height * width,
+        height * width, width, 1);
     cudnn::setTensor4dDesc<Dtype>(&top_descs_[i],
         this->num_,
-        this->num_output_ / this->group_,
-        this->height_out_, this->width_out_,
-        this->num_output_ * this->height_out_ * this->width_out_,
-        this->height_out_ * this->width_out_,
-        this->width_out_, 1);
+        this->num_output_ / this->group_, height_out, width_out,
+        this->num_output_ * this->out_spatial_dim_,
+        this->out_spatial_dim_, width_out, 1);
     cudnn::setConvolutionDesc<Dtype>(&conv_descs_[i], bottom_descs_[i],
-        filter_desc_, this->pad_h_, this->pad_w_,
-        this->stride_h_, this->stride_w_);
+        filter_desc_, pad_h, pad_w, stride_h, stride_w);
   }
 
   // Tensor descriptor for bias.
index b4e802e..6911520 100644 (file)
@@ -14,15 +14,15 @@ __global__ void sync_conv_groups() { }
 template <typename Dtype>
 void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  const int* kernel_shape_data = this->kernel_shape_.cpu_data();
+  const int kernel_h = kernel_shape_data[0];
+  const int kernel_w = kernel_shape_data[1];
+  const size_t workspace_limit_bytes =
+      kernel_h * kernel_w * this->channels_ * sizeof(int) + 1;
+  const Dtype* weight = this->blobs_[0]->gpu_data();
   for (int i = 0; i < bottom.size(); ++i) {
     const Dtype* bottom_data = bottom[i]->gpu_data();
     Dtype* top_data = top[i]->mutable_gpu_data();
-    const Dtype* weight = this->blobs_[0]->gpu_data();
-
-    size_t workspace_limit_bytes = this->kernel_h_ *
-                                   this->kernel_w_ *
-                                   this->channels_ *
-                                   sizeof(int) + 1;
 
     // Forward through cuDNN in parallel over groups.
     for (int g = 0; g < this->group_; g++) {
@@ -69,7 +69,7 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
       CUDNN_CHECK(cudnnConvolutionForward(handle_[g],
             cudnn::dataType<Dtype>::one,
             bottom_descs_[i], bottom_data + bottom_offset_ * g,
-            filter_desc_, weight + weight_offset_ * g,
+            filter_desc_, weight + this->weight_offset_ * g,
             conv_descs_[i],
             algo, workspace, workspaceSizeInBytes,
             cudnn::dataType<Dtype>::zero,
@@ -128,7 +128,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
               top_descs_[i],    top_diff + top_offset_ * g,
               conv_descs_[i],
               cudnn::dataType<Dtype>::one,
-              filter_desc_, weight_diff + weight_offset_ * g));
+              filter_desc_, weight_diff + this->weight_offset_ * g));
       }
 
       // Gradient w.r.t. bottom data.
@@ -139,7 +139,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
         CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g],
               cudnn::dataType<Dtype>::one,
-              filter_desc_, weight + weight_offset_ * g,
+              filter_desc_, weight + this->weight_offset_ * g,
               top_descs_[i], top_diff + top_offset_ * g,
               conv_descs_[i],
               cudnn::dataType<Dtype>::zero,
index a461296..f1d1abf 100644 (file)
@@ -10,10 +10,18 @@ namespace caffe {
 
 template <typename Dtype>
 void DeconvolutionLayer<Dtype>::compute_output_shape() {
-  this->height_out_ = this->stride_h_ * (this->height_ - 1) + this->kernel_h_
-      - 2 * this->pad_h_;
-  this->width_out_ = this->stride_w_ * (this->width_ - 1) + this->kernel_w_
-      - 2 * this->pad_w_;
+  // input_shape_ + 1 to skip channel axis
+  const int* input_shape_data = this->input_shape_.cpu_data() + 1;
+  const int* kernel_shape_data = this->kernel_shape_.cpu_data();
+  const int* stride_data = this->stride_.cpu_data();
+  const int* pad_data = this->pad_.cpu_data();
+  this->output_shape_.clear();
+  for (int i = 0; i < this->num_spatial_axes_; ++i) {
+    const int input_dim = input_shape_data[i];
+    const int output_dim = stride_data[i] * (input_dim - 1)
+        + kernel_shape_data[i] - 2 * pad_data[i];
+    this->output_shape_.push_back(output_dim);
+  }
 }
 
 template <typename Dtype>
@@ -24,11 +32,11 @@ void DeconvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     const Dtype* bottom_data = bottom[i]->cpu_data();
     Dtype* top_data = top[i]->mutable_cpu_data();
     for (int n = 0; n < this->num_; ++n) {
-      this->backward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight,
-          top_data + top[i]->offset(n));
+      this->backward_cpu_gemm(bottom_data + n * this->bottom_dim_, weight,
+          top_data + n * this->top_dim_);
       if (this->bias_term_) {
         const Dtype* bias = this->blobs_[1]->cpu_data();
-        this->forward_cpu_bias(top_data + top[i]->offset(n), bias);
+        this->forward_cpu_bias(top_data + n * this->top_dim_, bias);
       }
     }
   }
@@ -47,21 +55,21 @@ void DeconvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     if (this->bias_term_ && this->param_propagate_down_[1]) {
       Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
       for (int n = 0; n < this->num_; ++n) {
-        this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n));
+        this->backward_cpu_bias(bias_diff, top_diff + n * this->top_dim_);
       }
     }
     if (this->param_propagate_down_[0] || propagate_down[i]) {
       for (int n = 0; n < this->num_; ++n) {
         // Gradient w.r.t. weight. Note that we will accumulate diffs.
         if (this->param_propagate_down_[0]) {
-          this->weight_cpu_gemm(top_diff + top[i]->offset(n),
-              bottom_data + bottom[i]->offset(n), weight_diff);
+          this->weight_cpu_gemm(top_diff + n * this->top_dim_,
+              bottom_data + n * this->bottom_dim_, weight_diff);
         }
         // Gradient w.r.t. bottom data, if necessary, reusing the column buffer
         // we might have just computed above.
         if (propagate_down[i]) {
-          this->forward_cpu_gemm(top_diff + top[i]->offset(n), weight,
-              bottom_diff + bottom[i]->offset(n),
+          this->forward_cpu_gemm(top_diff + n * this->top_dim_, weight,
+              bottom_diff + n * this->bottom_dim_,
               this->param_propagate_down_[0]);
         }
       }
index 8a1eed8..ea83f56 100644 (file)
@@ -16,11 +16,11 @@ void DeconvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     const Dtype* bottom_data = bottom[i]->gpu_data();
     Dtype* top_data = top[i]->mutable_gpu_data();
     for (int n = 0; n < this->num_; ++n) {
-      this->backward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight,
-          top_data + top[i]->offset(n));
+      this->backward_gpu_gemm(bottom_data + n * this->bottom_dim_, weight,
+          top_data + n * this->top_dim_);
       if (this->bias_term_) {
         const Dtype* bias = this->blobs_[1]->gpu_data();
-        this->forward_gpu_bias(top_data + top[i]->offset(n), bias);
+        this->forward_gpu_bias(top_data + n * this->top_dim_, bias);
       }
     }
   }
@@ -39,20 +39,20 @@ void DeconvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     if (this->bias_term_ && this->param_propagate_down_[1]) {
       Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
       for (int n = 0; n < this->num_; ++n) {
-        this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n));
+        this->backward_gpu_bias(bias_diff, top_diff + n * this->top_dim_);
       }
     }
     if (this->param_propagate_down_[0] || propagate_down[i]) {
       for (int n = 0; n < this->num_; ++n) {
         // gradient w.r.t. weight. Note that we will accumulate diffs.
         if (this->param_propagate_down_[0]) {
-          this->weight_gpu_gemm(top_diff + top[i]->offset(n),
-              bottom_data + bottom[i]->offset(n), weight_diff);
+          this->weight_gpu_gemm(top_diff + n * this->top_dim_,
+              bottom_data + n * this->bottom_dim_, weight_diff);
         }
         // gradient w.r.t. bottom data, if necessary.
         if (propagate_down[i]) {
-          this->forward_gpu_gemm(top_diff + top[i]->offset(n), weight,
-              bottom_diff + bottom[i]->offset(n),
+          this->forward_gpu_gemm(top_diff + this->top_dim_, weight,
+              bottom_diff + n * this->bottom_dim_,
               this->param_propagate_down_[0]);
         }
       }
index 1c80271..595c9db 100644 (file)
@@ -11,54 +11,106 @@ template <typename Dtype>
 void Im2colLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   ConvolutionParameter conv_param = this->layer_param_.convolution_param();
-  CHECK(!conv_param.has_kernel_size() !=
-      !(conv_param.has_kernel_h() && conv_param.has_kernel_w()))
-      << "Filter size is kernel_size OR kernel_h and kernel_w; not both";
-  CHECK(conv_param.has_kernel_size() ||
-      (conv_param.has_kernel_h() && conv_param.has_kernel_w()))
-      << "For non-square filters both kernel_h and kernel_w are required.";
-  CHECK((!conv_param.has_pad() && conv_param.has_pad_h()
-      && conv_param.has_pad_w())
-      || (!conv_param.has_pad_h() && !conv_param.has_pad_w()))
-      << "pad is pad OR pad_h and pad_w are required.";
-  CHECK((!conv_param.has_stride() && conv_param.has_stride_h()
-      && conv_param.has_stride_w())
-      || (!conv_param.has_stride_h() && !conv_param.has_stride_w()))
-      << "Stride is stride OR stride_h and stride_w are required.";
-  if (conv_param.has_kernel_size()) {
-    kernel_h_ = kernel_w_ = conv_param.kernel_size();
+  force_nd_im2col_ = conv_param.force_nd_im2col();
+  const int input_num_dims = bottom[0]->shape().size();
+  channel_axis_ = bottom[0]->CanonicalAxisIndex(conv_param.axis());
+  const int first_spatial_dim = channel_axis_ + 1;
+  num_spatial_axes_ = input_num_dims - first_spatial_dim;
+  CHECK_GE(num_spatial_axes_, 1);
+  vector<int> dim_blob_shape(1, num_spatial_axes_);
+  // Setup filter kernel dimensions (kernel_shape_).
+  kernel_shape_.Reshape(dim_blob_shape);
+  int* kernel_shape_data = kernel_shape_.mutable_cpu_data();
+  if (conv_param.has_kernel_h() || conv_param.has_kernel_w()) {
+    CHECK_EQ(num_spatial_axes_, 2)
+        << "kernel_h & kernel_w can only be used for 2D convolution.";
+    CHECK_EQ(0, conv_param.kernel_size_size())
+        << "Either kernel_size or kernel_h/w should be specified; not both.";
+    kernel_shape_data[0] = conv_param.kernel_h();
+    kernel_shape_data[1] = conv_param.kernel_w();
   } else {
-    kernel_h_ = conv_param.kernel_h();
-    kernel_w_ = conv_param.kernel_w();
+    const int num_kernel_dims = conv_param.kernel_size_size();
+    CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_)
+        << "kernel_size must be specified once, or once per spatial dimension "
+        << "(kernel_size specified " << num_kernel_dims << " times; "
+        << num_spatial_axes_ << " spatial dims);";
+      for (int i = 0; i < num_spatial_axes_; ++i) {
+        kernel_shape_data[i] =
+            conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i);
+      }
   }
-  CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
-  CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
-  if (!conv_param.has_pad_h()) {
-    pad_h_ = pad_w_ = conv_param.pad();
+  for (int i = 0; i < num_spatial_axes_; ++i) {
+    CHECK_GT(kernel_shape_data[i], 0) << "Filter dimensions must be nonzero.";
+  }
+  // Setup stride dimensions (stride_).
+  stride_.Reshape(dim_blob_shape);
+  int* stride_data = stride_.mutable_cpu_data();
+  if (conv_param.has_stride_h() || conv_param.has_stride_w()) {
+    CHECK_EQ(num_spatial_axes_, 2)
+        << "stride_h & stride_w can only be used for 2D convolution.";
+    CHECK_EQ(0, conv_param.stride_size())
+        << "Either stride or stride_h/w should be specified; not both.";
+    stride_data[0] = conv_param.stride_h();
+    stride_data[1] = conv_param.stride_w();
   } else {
-    pad_h_ = conv_param.pad_h();
-    pad_w_ = conv_param.pad_w();
+    const int num_stride_dims = conv_param.stride_size();
+    CHECK(num_stride_dims == 0 || num_stride_dims == 1 ||
+          num_stride_dims == num_spatial_axes_)
+        << "stride must be specified once, or once per spatial dimension "
+        << "(stride specified " << num_stride_dims << " times; "
+        << num_spatial_axes_ << " spatial dims);";
+    const int kDefaultStride = 1;
+    for (int i = 0; i < num_spatial_axes_; ++i) {
+      stride_data[i] = (num_stride_dims == 0) ? kDefaultStride :
+          conv_param.stride((num_stride_dims == 1) ? 0 : i);
+      CHECK_GT(stride_data[i], 0) << "Stride dimensions must be nonzero.";
+    }
   }
-  if (!conv_param.has_stride_h()) {
-    stride_h_ = stride_w_ = conv_param.stride();
+  // Setup pad dimensions (pad_).
+  pad_.Reshape(dim_blob_shape);
+  int* pad_data = pad_.mutable_cpu_data();
+  if (conv_param.has_pad_h() || conv_param.has_pad_w()) {
+    CHECK_EQ(num_spatial_axes_, 2)
+        << "pad_h & pad_w can only be used for 2D convolution.";
+    CHECK_EQ(0, conv_param.pad_size())
+        << "Either pad or pad_h/w should be specified; not both.";
+    pad_data[0] = conv_param.pad_h();
+    pad_data[1] = conv_param.pad_w();
   } else {
-    stride_h_ = conv_param.stride_h();
-    stride_w_ = conv_param.stride_w();
+    const int num_pad_dims = conv_param.pad_size();
+    CHECK(num_pad_dims == 0 || num_pad_dims == 1 ||
+          num_pad_dims == num_spatial_axes_)
+        << "pad must be specified once, or once per spatial dimension "
+        << "(pad specified " << num_pad_dims << " times; "
+        << num_spatial_axes_ << " spatial dims);";
+    const int kDefaultPad = 0;
+    for (int i = 0; i < num_spatial_axes_; ++i) {
+      pad_data[i] = (num_pad_dims == 0) ? kDefaultPad :
+          conv_param.pad((num_pad_dims == 1) ? 0 : i);
+    }
   }
 }
 
 template <typename Dtype>
 void Im2colLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-  CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, "
-      << "corresponding to (num, channels, height, width)";
-  channels_ = bottom[0]->channels();
-  height_ = bottom[0]->height();
-  width_ = bottom[0]->width();
-  top[0]->Reshape(
-      bottom[0]->num(), channels_ * kernel_h_ * kernel_w_,
-      (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1,
-      (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1);
+  vector<int> top_shape = bottom[0]->shape();
+  const int* kernel_shape_data = kernel_shape_.cpu_data();
+  const int* stride_data = stride_.cpu_data();
+  const int* pad_data = pad_.cpu_data();
+  for (int i = 0; i < num_spatial_axes_; ++i) {
+    top_shape[channel_axis_] *= kernel_shape_data[i];
+    const int input_dim = bottom[0]->shape(channel_axis_ + i + 1);
+    const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i])
+        / stride_data[i] + 1;
+    top_shape[channel_axis_ + i + 1] = output_dim;
+  }
+  top[0]->Reshape(top_shape);
+  num_ = bottom[0]->count(0, channel_axis_);
+  bottom_dim_ = bottom[0]->count(channel_axis_);
+  top_dim_ = top[0]->count(channel_axis_);
+
+  channels_ = bottom[0]->shape(channel_axis_);
 }
 
 template <typename Dtype>
@@ -66,10 +118,27 @@ void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = top[0]->mutable_cpu_data();
-  for (int n = 0; n < bottom[0]->num(); ++n) {
-    im2col_cpu(bottom_data + bottom[0]->offset(n), channels_, height_,
-        width_, kernel_h_, kernel_w_, pad_h_, pad_w_,
-        stride_h_, stride_w_, top_data + top[0]->offset(n));
+  for (int n = 0; n < num_; ++n) {
+    DCHECK_EQ(bottom[0]->shape().size() - channel_axis_, num_spatial_axes_ + 1);
+    DCHECK_EQ(top[0]->shape().size() - channel_axis_, num_spatial_axes_ + 1);
+    DCHECK_EQ(kernel_shape_.count(), num_spatial_axes_);
+    DCHECK_EQ(pad_.count(), num_spatial_axes_);
+    DCHECK_EQ(stride_.count(), num_spatial_axes_);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      im2col_cpu(bottom_data + n * bottom_dim_, channels_,
+          bottom[0]->shape(channel_axis_ + 1),
+          bottom[0]->shape(channel_axis_ + 2),
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1],
+          top_data + n * top_dim_);
+    } else {
+      im2col_nd_cpu(bottom_data + n * bottom_dim_, num_spatial_axes_,
+          bottom[0]->shape().data() + channel_axis_,
+          top[0]->shape().data() + channel_axis_,
+          kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(),
+          top_data + n * top_dim_);
+    }
   }
 }
 
@@ -78,10 +147,22 @@ void Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
   Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
-  for (int n = 0; n < top[0]->num(); ++n) {
-    col2im_cpu(top_diff + top[0]->offset(n), channels_, height_, width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_,
-        stride_h_, stride_w_, bottom_diff + bottom[0]->offset(n));
+  for (int n = 0; n < num_; ++n) {
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      col2im_cpu(top_diff + n * top_dim_, channels_,
+          bottom[0]->shape(channel_axis_ + 1),
+          bottom[0]->shape(channel_axis_ + 2),
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1],
+          bottom_diff + n * bottom_dim_);
+    } else {
+      col2im_nd_cpu(top_diff + n * top_dim_, num_spatial_axes_,
+          bottom[0]->shape().data() + channel_axis_,
+          top[0]->shape().data() + channel_axis_,
+          kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(),
+          bottom_diff + n * bottom_dim_);
+    }
   }
 }
 
index 9c338b1..cd50762 100644 (file)
@@ -12,10 +12,23 @@ void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
-  for (int n = 0; n < bottom[0]->num(); ++n) {
-    im2col_gpu(bottom_data + bottom[0]->offset(n), channels_, height_,
-        width_, kernel_h_, kernel_w_, pad_h_, pad_w_,
-        stride_h_, stride_w_, top_data + top[0]->offset(n));
+  const int num_kernels = channels_ * top[0]->count(channel_axis_ + 1);
+  for (int n = 0; n < num_; ++n) {
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      im2col_gpu(bottom_data + n * bottom_dim_, channels_,
+          bottom[0]->shape(channel_axis_ + 1),
+          bottom[0]->shape(channel_axis_ + 2),
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1],
+          top_data + n * top_dim_);
+    } else {
+      im2col_nd_gpu(bottom_data + n * bottom_dim_, num_spatial_axes_,
+          num_kernels, bottom[0]->gpu_shape() + channel_axis_,
+          top[0]->gpu_shape() + channel_axis_,
+          kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
+          top_data + n * top_dim_);
+    }
   }
 }
 
@@ -24,10 +37,22 @@ void Im2colLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
   const Dtype* top_diff = top[0]->gpu_diff();
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
-  for (int n = 0; n < top[0]->num(); ++n) {
-    col2im_gpu(top_diff + top[0]->offset(n), channels_, height_, width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_,
-        stride_h_, stride_w_, bottom_diff + bottom[0]->offset(n));
+  for (int n = 0; n < num_; ++n) {
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      col2im_gpu(top_diff + n * top_dim_, channels_,
+          bottom[0]->shape(channel_axis_ + 1),
+          bottom[0]->shape(channel_axis_ + 2),
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1],
+          bottom_diff + n * bottom_dim_);
+    } else {
+      col2im_nd_gpu(top_diff + n * top_dim_, num_spatial_axes_, bottom_dim_,
+          bottom[0]->gpu_shape() + channel_axis_,
+          top[0]->gpu_shape() + channel_axis_,
+          kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
+          bottom_diff + n * bottom_dim_);
+    }
   }
 }
 
index 86683eb..f52c941 100644 (file)
@@ -508,6 +508,13 @@ message ConvolutionParameter {
   // N independent 3D convolutions, sliding (C/g)-channels
   // filters across the spatial axes (D, H, W) of the input.
   optional int32 axis = 16 [default = 1];
+
+  // Whether to force use of the general ND convolution, even if a specific
+  // implementation for blobs of the appropriate number of spatial dimensions
+  // is available. (Currently, there is only a 2D-specific convolution
+  // implementation; for input blobs with num_axes != 2, this option is
+  // ignored and the ND implementation will be used.)
+  optional bool force_nd_im2col = 17 [default = false];
 }
 
 message DataParameter {
index 67d41ff..9df979a 100644 (file)
@@ -19,54 +19,87 @@ template <typename Dtype>
 void caffe_conv(const Blob<Dtype>* in, ConvolutionParameter* conv_param,
     const vector<shared_ptr<Blob<Dtype> > >& weights,
     Blob<Dtype>* out) {
+  const bool has_depth = (out->num_axes() == 5);
+  if (!has_depth) { CHECK_EQ(4, out->num_axes()); }
   // Kernel size, stride, and pad
   int kernel_h, kernel_w;
-  if (conv_param->has_kernel_size()) {
-    kernel_h = kernel_w = conv_param->kernel_size();
-  } else {
+  if (conv_param->has_kernel_h() || conv_param->has_kernel_w()) {
     kernel_h = conv_param->kernel_h();
     kernel_w = conv_param->kernel_w();
+  } else {
+    kernel_h = kernel_w = conv_param->kernel_size(0);
   }
   int pad_h, pad_w;
-  if (!conv_param->has_pad_h()) {
-    pad_h = pad_w = conv_param->pad();
-  } else {
+  if (conv_param->has_pad_h() || conv_param->has_pad_w()) {
     pad_h = conv_param->pad_h();
     pad_w = conv_param->pad_w();
+  } else {
+    pad_h = pad_w = conv_param->pad_size() ? conv_param->pad(0) : 0;
   }
   int stride_h, stride_w;
-  if (!conv_param->has_stride_h()) {
-    stride_h = stride_w = conv_param->stride();
-  } else {
+  if (conv_param->has_stride_h() || conv_param->has_stride_w()) {
     stride_h = conv_param->stride_h();
     stride_w = conv_param->stride_w();
+  } else {
+    stride_h = stride_w = conv_param->stride_size() ? conv_param->stride(0) : 1;
+  }
+  int kernel_d, pad_d, stride_d;
+  if (has_depth) {
+    kernel_d = kernel_h;
+    stride_d = stride_h;
+    pad_d = pad_h;
+  } else {
+    kernel_d = stride_d = 1;
+    pad_d = 0;
   }
   // Groups
   int groups = conv_param->group();
-  int o_g = out->channels() / groups;
-  int k_g = in->channels() / groups;
+  int o_g = out->shape(1) / groups;
+  int k_g = in->shape(1) / groups;
   int o_head, k_head;
   // Convolution
-  const Dtype* in_data = in->cpu_data();
-  const Dtype* weight_data = weights[0]->cpu_data();
+  vector<int> weight_offset(4 + has_depth);
+  vector<int> in_offset(4 + has_depth);
+  vector<int> out_offset(4 + has_depth);
   Dtype* out_data = out->mutable_cpu_data();
-  for (int n = 0; n < out->num(); n++) {
+  for (int n = 0; n < out->shape(0); n++) {
     for (int g = 0; g < groups; g++) {
       o_head = o_g * g;
       k_head = k_g * g;
       for (int o = 0; o < o_g; o++) {
         for (int k = 0; k < k_g; k++) {
-          for (int y = 0; y < out->height(); y++) {
-            for (int x = 0; x < out->width(); x++) {
-              for (int p = 0; p < kernel_h; p++) {
-                for (int q = 0; q < kernel_w; q++) {
-                  int in_y = y * stride_h - pad_h + p;
-                  int in_x = x * stride_w - pad_w + q;
-                  if (in_y >= 0 && in_y < in->height()
-                    && in_x >= 0 && in_x < in->width()) {
-                    out_data[out->offset(n, o + o_head, y, x)] +=
-                        in_data[in->offset(n, k + k_head, in_y, in_x)]
-                        * weight_data[weights[0]->offset(o + o_head, k, p, q)];
+          for (int z = 0; z < (has_depth ? out->shape(2) : 1); z++) {
+            for (int y = 0; y < out->shape(2 + has_depth); y++) {
+              for (int x = 0; x < out->shape(3 + has_depth); x++) {
+                for (int r = 0; r < kernel_d; r++) {
+                  for (int p = 0; p < kernel_h; p++) {
+                    for (int q = 0; q < kernel_w; q++) {
+                      int in_z = z * stride_d - pad_d + r;
+                      int in_y = y * stride_h - pad_h + p;
+                      int in_x = x * stride_w - pad_w + q;
+                      if (in_z >= 0 && in_z < (has_depth ? in->shape(2) : 1)
+                          && in_y >= 0 && in_y < in->shape(2 + has_depth)
+                          && in_x >= 0 && in_x < in->shape(3 + has_depth)) {
+                        weight_offset[0] = o + o_head;
+                        weight_offset[1] = k;
+                        if (has_depth) { weight_offset[2] = r; }
+                        weight_offset[2 + has_depth] = p;
+                        weight_offset[3 + has_depth] = q;
+                        in_offset[0] = n;
+                        in_offset[1] = k + k_head;
+                        if (has_depth) { in_offset[2] = in_z; }
+                        in_offset[2 + has_depth] = in_y;
+                        in_offset[3 + has_depth] = in_x;
+                        out_offset[0] = n;
+                        out_offset[1] = o + o_head;
+                        if (has_depth) { out_offset[2] = z; }
+                        out_offset[2 + has_depth] = y;
+                        out_offset[3 + has_depth] = x;
+                        out_data[out->offset(out_offset)] +=
+                            in->data_at(in_offset)
+                            * weights[0]->data_at(weight_offset);
+                      }
+                    }
                   }
                 }
               }
@@ -79,11 +112,18 @@ void caffe_conv(const Blob<Dtype>* in, ConvolutionParameter* conv_param,
   // Bias
   if (conv_param->bias_term()) {
     const Dtype* bias_data = weights[1]->cpu_data();
-    for (int n = 0; n < out->num(); n++) {
-      for (int o = 0; o < out->channels(); o++) {
-        for (int y = 0; y < out->height(); y++) {
-          for (int x = 0; x < out->width(); x++) {
-            out_data[out->offset(n, o, y, x)] += bias_data[o];
+    for (int n = 0; n < out->shape(0); n++) {
+      for (int o = 0; o < out->shape(1); o++) {
+        for (int z = 0; z < (has_depth ? out->shape(2) : 1); z++) {
+          for (int y = 0; y < out->shape(2 + has_depth); y++) {
+            for (int x = 0; x < out->shape(3 + has_depth); x++) {
+              out_offset[0] = n;
+              out_offset[1] = o;
+              if (has_depth) { out_offset[2] = z; }
+              out_offset[2 + has_depth] = y;
+              out_offset[3 + has_depth] = x;
+              out_data[out->offset(out_offset)] += bias_data[o];
+            }
           }
         }
       }
@@ -150,8 +190,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSetup) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(4);
   this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
   this->blob_top_vec_.push_back(this->blob_top_2_);
@@ -188,8 +228,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(4);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
   convolution_param->mutable_bias_filler()->set_type("constant");
@@ -217,13 +257,98 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) {
   }
 }
 
+TYPED_TEST(ConvolutionLayerTest, Test0DConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  const int kNumOutput = 3;
+  convolution_param->set_num_output(kNumOutput);
+  convolution_param->set_axis(3);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+  shared_ptr<Layer<Dtype> > layer(
+      new ConvolutionLayer<Dtype>(layer_param));
+  vector<int> top_shape = this->blob_bottom_->shape();
+  top_shape[3] = kNumOutput;
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(top_shape, this->blob_top_->shape());
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Check against reference convolution.
+  vector<int> weight_offset(2);
+  const Blob<Dtype>* weight = layer->blobs()[0].get();
+  const Blob<Dtype>* bias = layer->blobs()[1].get();
+  const int num = this->blob_top_->count(3);
+  const int dim = this->blob_top_->shape(3);
+  const int bottom_dim = this->blob_bottom_->shape(3);
+  for (int n = 0; n < num; ++n) {
+    for (int d = 0; d < dim; ++d) {
+      weight_offset[0] = d;
+      Dtype value = bias->cpu_data()[d];
+      for (int bottom_d = 0; bottom_d < bottom_dim; ++bottom_d) {
+        weight_offset[1] = bottom_d;
+        value += weight->data_at(weight_offset) *
+                 this->blob_bottom_->cpu_data()[n * bottom_dim + bottom_d];
+      }
+      EXPECT_NEAR(value, this->blob_top_->cpu_data()[n * dim + d], 1e-4);
+    }
+  }
+}
+
+TYPED_TEST(ConvolutionLayerTest, TestSimple3DConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  vector<int> bottom_shape(5);
+  bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0);
+  bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1);
+  bottom_shape[2] = 5;
+  bottom_shape[3] = this->blob_bottom_vec_[0]->shape(2);
+  bottom_shape[4] = this->blob_bottom_vec_[0]->shape(3);
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+  shared_ptr<Layer<Dtype> > layer(
+      new ConvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Check against reference convolution.
+  const Dtype* top_data;
+  const Dtype* ref_top_data;
+  caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(),
+      this->MakeReferenceTop(this->blob_top_));
+  top_data = this->blob_top_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+  caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(),
+      this->MakeReferenceTop(this->blob_top_2_));
+  top_data = this->blob_top_2_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+}
+
 TYPED_TEST(ConvolutionLayerTest, Test1x1Convolution) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(1);
-  convolution_param->set_stride(1);
+  convolution_param->add_kernel_size(1);
+  convolution_param->add_stride(1);
   convolution_param->set_num_output(4);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
   convolution_param->mutable_bias_filler()->set_type("constant");
@@ -249,8 +374,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(3);
   convolution_param->set_group(3);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
@@ -288,8 +413,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(1);
   convolution_param->set_bias_term(false);
   shared_ptr<Layer<Dtype> > layer(
@@ -350,14 +475,11 @@ TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) {
   convolution_param->set_bias_term(false);
   layer.reset(new ConvolutionLayer<Dtype>(layer_param));
   layer->blobs().resize(1);
-  layer->blobs()[0].reset(new Blob<Dtype>(1, 3, 1, 3));
+  layer->blobs()[0].reset(new Blob<Dtype>(1, 1, 1, 3));
   Dtype* weights_2 = layer->blobs()[0]->mutable_cpu_data();
-  for (int c = 0; c < 3; ++c) {
-    int i = c * 3;  // 1 x 3 filter
-    weights_2[i +  0] = -1;
-    weights_2[i +  1] =  0;
-    weights_2[i +  2] =  1;
-  }
+  weights_2[0] = -1;
+  weights_2[1] =  0;
+  weights_2[2] =  1;
   layer->SetUp(sep_blob_bottom_vec, sep_blob_top_vec);
   layer->Forward(sep_blob_bottom_vec, sep_blob_top_vec);
   // Test equivalence of full and separable filters.
@@ -368,6 +490,124 @@ TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) {
   }
 }
 
+TYPED_TEST(ConvolutionLayerTest, TestNDAgainst2D) {
+  typedef typename TypeParam::Dtype Dtype;
+  const int kernel_h = 11;
+  const int kernel_w = 13;
+  vector<int> bottom_shape(4);
+  bottom_shape[0] = 15;
+  bottom_shape[1] = 18;
+  bottom_shape[2] = kernel_h * 2;
+  bottom_shape[3] = kernel_w * 2;
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->set_num_output(12);
+  convolution_param->set_bias_term(false);
+  convolution_param->set_group(6);
+  convolution_param->set_kernel_h(kernel_h);
+  convolution_param->set_kernel_w(kernel_w);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  Blob<Dtype> weights;
+  Blob<Dtype> top_diff;
+  // Shape and fill weights and top_diff.
+  bool copy_diff;
+  bool reshape;
+  {
+    ConvolutionLayer<Dtype> layer(layer_param);
+    layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    top_diff.ReshapeLike(*this->blob_top_);
+    filler.Fill(&top_diff);
+    ASSERT_EQ(1, layer.blobs().size());
+    copy_diff = false; reshape = true;
+    weights.CopyFrom(*layer.blobs()[0], copy_diff, reshape);
+  }
+  vector<bool> propagate_down(1, true);
+  Blob<Dtype> result_2d;
+  Blob<Dtype> backward_result_2d;
+  Blob<Dtype> backward_weight_result_2d;
+  // Test with 2D im2col
+  {
+    caffe_set(this->blob_top_->count(), Dtype(0),
+              this->blob_top_->mutable_cpu_data());
+    caffe_set(this->blob_bottom_->count(), Dtype(0),
+              this->blob_bottom_->mutable_cpu_diff());
+    caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff());
+    // Do SetUp and Forward; save Forward result in result_2d.
+    convolution_param->set_force_nd_im2col(false);
+    ConvolutionLayer<Dtype> layer_2d(layer_param);
+    layer_2d.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    ASSERT_EQ(1, layer_2d.blobs().size());
+    copy_diff = false; reshape = false;
+    layer_2d.blobs()[0]->CopyFrom(weights, copy_diff, reshape);
+    layer_2d.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    copy_diff = false; reshape = true;
+    result_2d.CopyFrom(*this->blob_top_, copy_diff, reshape);
+    // Copy pre-generated top diff into actual top diff;
+    // do Backward and save result in backward_result_2d.
+    ASSERT_EQ(this->blob_top_->shape(), top_diff.shape());
+    caffe_copy(top_diff.count(), top_diff.cpu_data(),
+               this->blob_top_->mutable_cpu_diff());
+    layer_2d.Backward(this->blob_top_vec_, propagate_down,
+                      this->blob_bottom_vec_);
+    copy_diff = true; reshape = true;
+    backward_result_2d.CopyFrom(*this->blob_bottom_, copy_diff, reshape);
+    backward_weight_result_2d.CopyFrom(weights, copy_diff, reshape);
+  }
+  Blob<Dtype> result_nd;
+  Blob<Dtype> backward_result_nd;
+  Blob<Dtype> backward_weight_result_nd;
+  // Test with ND im2col
+  {
+    caffe_set(this->blob_top_->count(), Dtype(0),
+              this->blob_top_->mutable_cpu_data());
+    caffe_set(this->blob_bottom_->count(), Dtype(0),
+              this->blob_bottom_->mutable_cpu_diff());
+    caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff());
+    // Do SetUp and Forward; save Forward result in result_nd.
+    convolution_param->set_force_nd_im2col(true);
+    ConvolutionLayer<Dtype> layer_nd(layer_param);
+    layer_nd.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    ASSERT_EQ(1, layer_nd.blobs().size());
+    copy_diff = false; reshape = false;
+    layer_nd.blobs()[0]->CopyFrom(weights, copy_diff, reshape);
+    layer_nd.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    copy_diff = false; reshape = true;
+    result_nd.CopyFrom(*this->blob_top_, copy_diff, reshape);
+    // Copy pre-generated top diff into actual top diff;
+    // do Backward and save result in backward_result_nd.
+    ASSERT_EQ(this->blob_top_->shape(), top_diff.shape());
+    caffe_copy(top_diff.count(), top_diff.cpu_data(),
+               this->blob_top_->mutable_cpu_diff());
+    layer_nd.Backward(this->blob_top_vec_, propagate_down,
+                      this->blob_bottom_vec_);
+    copy_diff = true; reshape = true;
+    backward_result_nd.CopyFrom(*this->blob_bottom_, copy_diff, reshape);
+    backward_weight_result_nd.CopyFrom(weights, copy_diff, reshape);
+  }
+  ASSERT_EQ(result_nd.count(), result_2d.count());
+  for (int i = 0; i < result_2d.count(); ++i)  {
+    EXPECT_EQ(result_2d.cpu_data()[i], result_nd.cpu_data()[i]);
+  }
+  ASSERT_EQ(backward_result_nd.count(), backward_result_2d.count());
+  for (int i = 0; i < backward_result_2d.count(); ++i) {
+    EXPECT_EQ(backward_result_2d.cpu_diff()[i],
+              backward_result_nd.cpu_diff()[i]);
+  }
+  ASSERT_EQ(backward_weight_result_nd.count(),
+            backward_weight_result_2d.count());
+  for (int i = 0; i < backward_weight_result_2d.count(); ++i) {
+    EXPECT_EQ(backward_weight_result_2d.cpu_diff()[i],
+              backward_weight_result_nd.cpu_diff()[i]);
+  }
+}
+
 TYPED_TEST(ConvolutionLayerTest, TestGradient) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
@@ -375,8 +615,36 @@ TYPED_TEST(ConvolutionLayerTest, TestGradient) {
       layer_param.mutable_convolution_param();
   this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
   this->blob_top_vec_.push_back(this->blob_top_2_);
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(2);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+  ConvolutionLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(ConvolutionLayerTest, TestGradient3D) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  vector<int> bottom_shape(5);
+  bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0);
+  bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1);
+  bottom_shape[2] = 5;
+  bottom_shape[3] = this->blob_bottom_vec_[0]->shape(2);
+  bottom_shape[4] = this->blob_bottom_vec_[0]->shape(3);
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(2);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
   convolution_param->mutable_bias_filler()->set_type("gaussian");
@@ -393,8 +661,8 @@ TYPED_TEST(ConvolutionLayerTest, Test1x1Gradient) {
       layer_param.mutable_convolution_param();
   this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
   this->blob_top_vec_.push_back(this->blob_top_2_);
-  convolution_param->set_kernel_size(1);
-  convolution_param->set_stride(1);
+  convolution_param->add_kernel_size(1);
+  convolution_param->add_stride(1);
   convolution_param->set_num_output(2);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
   convolution_param->mutable_bias_filler()->set_type("gaussian");
@@ -409,8 +677,8 @@ TYPED_TEST(ConvolutionLayerTest, TestGradientGroup) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(3);
   convolution_param->set_group(3);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
@@ -472,8 +740,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSetupCuDNN) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(4);
   this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
   this->blob_top_vec_.push_back(this->blob_top_2_);
@@ -509,8 +777,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(4);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
   convolution_param->mutable_bias_filler()->set_type("constant");
@@ -542,8 +810,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(3);
   convolution_param->set_group(3);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
@@ -581,8 +849,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(1);
   convolution_param->set_bias_term(false);
   shared_ptr<Layer<TypeParam> > layer(
@@ -643,14 +911,11 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) {
   convolution_param->set_bias_term(false);
   layer.reset(new CuDNNConvolutionLayer<TypeParam>(layer_param));
   layer->blobs().resize(1);
-  layer->blobs()[0].reset(new Blob<TypeParam>(1, 3, 1, 3));
+  layer->blobs()[0].reset(new Blob<TypeParam>(1, 1, 1, 3));
   TypeParam* weights_2 = layer->blobs()[0]->mutable_cpu_data();
-  for (int c = 0; c < 3; ++c) {
-    int i = c * 3;  // 1 x 3 filter
-    weights_2[i +  0] = -1;
-    weights_2[i +  1] =  0;
-    weights_2[i +  2] =  1;
-  }
+  weights_2[0] = -1;
+  weights_2[1] =  0;
+  weights_2[2] =  1;
   layer->SetUp(sep_blob_bottom_vec, sep_blob_top_vec);
   layer->Forward(sep_blob_bottom_vec, sep_blob_top_vec);
   // Test equivalence of full and separable filters.
@@ -667,8 +932,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientCuDNN) {
       layer_param.mutable_convolution_param();
   this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
   this->blob_top_vec_.push_back(this->blob_top_2_);
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(2);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
   convolution_param->mutable_bias_filler()->set_type("gaussian");
@@ -682,8 +947,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientGroupCuDNN) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(3);
   convolution_param->set_group(3);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
index fc63d5e..770e7b2 100644 (file)
@@ -58,8 +58,8 @@ TYPED_TEST(DeconvolutionLayerTest, TestSetup) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(4);
   this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
   this->blob_top_vec_.push_back(this->blob_top_2_);
@@ -96,8 +96,8 @@ TYPED_TEST(DeconvolutionLayerTest, TestSimpleDeconvolution) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   convolution_param->set_num_output(4);
   convolution_param->mutable_weight_filler()->set_type("constant");
   convolution_param->mutable_weight_filler()->set_value(1);
@@ -144,8 +144,8 @@ TYPED_TEST(DeconvolutionLayerTest, TestGradient) {
       layer_param.mutable_convolution_param();
   this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
   this->blob_top_vec_.push_back(this->blob_top_2_);
-  convolution_param->set_kernel_size(2);
-  convolution_param->set_stride(1);
+  convolution_param->add_kernel_size(2);
+  convolution_param->add_stride(1);
   convolution_param->set_num_output(1);
   convolution_param->mutable_weight_filler()->set_type("gaussian");
   convolution_param->mutable_bias_filler()->set_type("gaussian");
@@ -155,4 +155,151 @@ TYPED_TEST(DeconvolutionLayerTest, TestGradient) {
       this->blob_top_vec_);
 }
 
+TYPED_TEST(DeconvolutionLayerTest, TestNDAgainst2D) {
+  typedef typename TypeParam::Dtype Dtype;
+  const int kernel_h = 11;
+  const int kernel_w = 13;
+  vector<int> bottom_shape(4);
+  bottom_shape[0] = 15;
+  bottom_shape[1] = 12;
+  bottom_shape[2] = kernel_h * 2;
+  bottom_shape[3] = kernel_w * 2;
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->set_num_output(18);
+  convolution_param->set_bias_term(false);
+  convolution_param->set_group(6);
+  convolution_param->set_kernel_h(kernel_h);
+  convolution_param->set_kernel_w(kernel_w);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  Blob<Dtype> weights;
+  Blob<Dtype> top_diff;
+  // Shape and fill weights and top_diff.
+  bool copy_diff;
+  bool reshape;
+  {
+    DeconvolutionLayer<Dtype> layer(layer_param);
+    layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    top_diff.ReshapeLike(*this->blob_top_);
+    filler.Fill(&top_diff);
+    ASSERT_EQ(1, layer.blobs().size());
+    copy_diff = false; reshape = true;
+    weights.CopyFrom(*layer.blobs()[0], copy_diff, reshape);
+  }
+  vector<bool> propagate_down(1, true);
+  Blob<Dtype> result_2d;
+  Blob<Dtype> backward_result_2d;
+  Blob<Dtype> backward_weight_result_2d;
+  // Test with 2D im2col
+  {
+    caffe_set(this->blob_top_->count(), Dtype(0),
+              this->blob_top_->mutable_cpu_data());
+    caffe_set(this->blob_bottom_->count(), Dtype(0),
+              this->blob_bottom_->mutable_cpu_diff());
+    caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff());
+    // Do SetUp and Forward; save Forward result in result_2d.
+    convolution_param->set_force_nd_im2col(false);
+    DeconvolutionLayer<Dtype> layer_2d(layer_param);
+    layer_2d.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    ASSERT_EQ(1, layer_2d.blobs().size());
+    copy_diff = false; reshape = false;
+    layer_2d.blobs()[0]->CopyFrom(weights, copy_diff, reshape);
+    layer_2d.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    copy_diff = false; reshape = true;
+    result_2d.CopyFrom(*this->blob_top_, copy_diff, reshape);
+    // Copy pre-generated top diff into actual top diff;
+    // do Backward and save result in backward_result_2d.
+    ASSERT_EQ(this->blob_top_->shape(), top_diff.shape());
+    caffe_copy(top_diff.count(), top_diff.cpu_data(),
+               this->blob_top_->mutable_cpu_diff());
+    layer_2d.Backward(this->blob_top_vec_, propagate_down,
+                      this->blob_bottom_vec_);
+    copy_diff = true; reshape = true;
+    backward_result_2d.CopyFrom(*this->blob_bottom_, copy_diff, reshape);
+    backward_weight_result_2d.CopyFrom(weights, copy_diff, reshape);
+  }
+  Blob<Dtype> result_nd;
+  Blob<Dtype> backward_result_nd;
+  Blob<Dtype> backward_weight_result_nd;
+  // Test with ND im2col
+  {
+    caffe_set(this->blob_top_->count(), Dtype(0),
+              this->blob_top_->mutable_cpu_data());
+    caffe_set(this->blob_bottom_->count(), Dtype(0),
+              this->blob_bottom_->mutable_cpu_diff());
+    caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff());
+    // Do SetUp and Forward; save Forward result in result_nd.
+    convolution_param->set_force_nd_im2col(true);
+    DeconvolutionLayer<Dtype> layer_nd(layer_param);
+    layer_nd.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    ASSERT_EQ(1, layer_nd.blobs().size());
+    copy_diff = false; reshape = false;
+    layer_nd.blobs()[0]->CopyFrom(weights, copy_diff, reshape);
+    layer_nd.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    copy_diff = false; reshape = true;
+    result_nd.CopyFrom(*this->blob_top_, copy_diff, reshape);
+    // Copy pre-generated top diff into actual top diff;
+    // do Backward and save result in backward_result_nd.
+    ASSERT_EQ(this->blob_top_->shape(), top_diff.shape());
+    caffe_copy(top_diff.count(), top_diff.cpu_data(),
+               this->blob_top_->mutable_cpu_diff());
+    layer_nd.Backward(this->blob_top_vec_, propagate_down,
+                      this->blob_bottom_vec_);
+    copy_diff = true; reshape = true;
+    backward_result_nd.CopyFrom(*this->blob_bottom_, copy_diff, reshape);
+    backward_weight_result_nd.CopyFrom(weights, copy_diff, reshape);
+  }
+  ASSERT_EQ(result_nd.count(), result_2d.count());
+  for (int i = 0; i < result_2d.count(); ++i)  {
+    EXPECT_EQ(result_2d.cpu_data()[i], result_nd.cpu_data()[i]);
+  }
+  ASSERT_EQ(backward_result_nd.count(), backward_result_2d.count());
+  for (int i = 0; i < backward_result_2d.count(); ++i) {
+    EXPECT_EQ(backward_result_2d.cpu_diff()[i],
+              backward_result_nd.cpu_diff()[i]);
+  }
+  ASSERT_EQ(backward_weight_result_nd.count(),
+            backward_weight_result_2d.count());
+  for (int i = 0; i < backward_weight_result_2d.count(); ++i) {
+    EXPECT_EQ(backward_weight_result_2d.cpu_diff()[i],
+              backward_weight_result_nd.cpu_diff()[i]);
+  }
+}
+
+TYPED_TEST(DeconvolutionLayerTest, TestGradient3D) {
+  typedef typename TypeParam::Dtype Dtype;
+  vector<int> bottom_shape(5);
+  bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0);
+  bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1);
+  bottom_shape[2] = 2;
+  bottom_shape[3] = 3;
+  bottom_shape[4] = 2;
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(2);
+  convolution_param->add_stride(2);
+  convolution_param->add_pad(1);
+  convolution_param->set_num_output(2);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+  DeconvolutionLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
 }  // namespace caffe
index 0017ac2..f0b75fc 100644 (file)
@@ -22,6 +22,12 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
     const int height_col, const int width_col,
     Dtype* data_col);
 
+template <typename Dtype, int num_axes>
+__global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_col);
+
 extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
 
 template <typename Dtype>
@@ -30,11 +36,18 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
   Im2colKernelTest()
         // big so launches > 1024 threads
       : blob_bottom_(new Blob<Dtype>(5, 500, 10, 10)),
+        blob_kernel_shape_(new Blob<int>()),
+        blob_stride_(new Blob<int>()),
+        blob_pad_(new Blob<int>()),
         blob_top_(new Blob<Dtype>()),
         blob_top_cpu_(new Blob<Dtype>()) {
     FillerParameter filler_param;
     GaussianFiller<Dtype> filler(filler_param);
     filler.Fill(this->blob_bottom_);
+    vector<int> dim_blob_shape(1, 2);
+    blob_kernel_shape_->Reshape(dim_blob_shape);
+    blob_stride_->Reshape(dim_blob_shape);
+    blob_pad_->Reshape(dim_blob_shape);
 
     height_ = blob_bottom_->height();
     width_ = blob_bottom_->width();
@@ -44,14 +57,26 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
     kernel_size_ = 3;
     height_col_ = (height_ + 2 * pad_ - kernel_size_) / stride_ + 1;
     width_col_ = (width_ + 2 * pad_ - kernel_size_) / stride_ + 1;
+
+    for (int i = 0; i < 2; ++i) {
+      blob_kernel_shape_->mutable_cpu_data()[i] = kernel_size_;
+      blob_stride_->mutable_cpu_data()[i] = stride_;
+      blob_pad_->mutable_cpu_data()[i] = pad_;
+    }
   }
 
   virtual ~Im2colKernelTest() {
-      delete blob_bottom_;
-      delete blob_top_;
-      delete blob_top_cpu_;
+    delete blob_bottom_;
+    delete blob_top_;
+    delete blob_top_cpu_;
+    delete blob_kernel_shape_;
+    delete blob_stride_;
+    delete blob_pad_;
   }
 
+  Blob<int>* const blob_kernel_shape_;
+  Blob<int>* const blob_stride_;
+  Blob<int>* const blob_pad_;
   Blob<Dtype>* const blob_bottom_;
   Blob<Dtype>* const blob_top_;
   Blob<Dtype>* const blob_top_cpu_;
@@ -67,7 +92,7 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
 
 TYPED_TEST_CASE(Im2colKernelTest, TestDtypes);
 
-TYPED_TEST(Im2colKernelTest, TestGPU) {
+TYPED_TEST(Im2colKernelTest, Test2D) {
   // Reshape the blobs to correct size for im2col output
   this->blob_top_->Reshape(this->blob_bottom_->num(),
           this->channels_ * this->kernel_size_ * this->kernel_size_,
@@ -122,4 +147,58 @@ TYPED_TEST(Im2colKernelTest, TestGPU) {
   }
 }
 
+TYPED_TEST(Im2colKernelTest, TestND) {
+  // Reshape the blobs to correct size for im2col output
+  this->blob_top_->Reshape(this->blob_bottom_->num(),
+      this->channels_ * this->kernel_size_ * this->kernel_size_,
+      this->height_col_,
+      this->width_col_);
+
+  this->blob_top_cpu_->ReshapeLike(*this->blob_top_);
+
+  const TypeParam* bottom_data_cpu = this->blob_bottom_->cpu_data();
+  TypeParam* top_data_cpu = this->blob_top_cpu_->mutable_cpu_data();
+
+  // CPU Version
+  for (int n = 0; n < this->blob_bottom_->num(); ++n) {
+    im2col_nd_cpu(bottom_data_cpu + this->blob_bottom_->offset(n), 2,
+        this->blob_bottom_->shape().data() + 1,
+        this->blob_top_cpu_->shape().data() + 1,
+        this->blob_kernel_shape_->cpu_data(),
+        this->blob_pad_->cpu_data(), this->blob_stride_->cpu_data(),
+        top_data_cpu + this->blob_top_cpu_->offset(n));
+  }
+
+  // GPU version
+  int num_kernels = this->channels_ * this->height_col_ * this->width_col_;
+  int default_grid_dim = CAFFE_GET_BLOCKS(num_kernels);
+  const TypeParam* bottom_data_gpu = this->blob_bottom_->gpu_data();
+
+  // Launch with different grid sizes
+  for (int grid_div = 2; grid_div <= 8; grid_div++) {
+    for (int n = 0; n < this->blob_bottom_->num(); ++n) {
+      const int grid_dim = default_grid_dim / grid_div;
+      TypeParam* top_data_gpu = this->blob_top_->mutable_gpu_data();
+      // NOLINT_NEXT_LINE(whitespace/operators)
+      im2col_nd_gpu_kernel<TypeParam, 2><<<grid_dim, CAFFE_CUDA_NUM_THREADS>>>(
+          num_kernels, bottom_data_gpu + this->blob_bottom_->offset(n),
+          this->blob_bottom_->gpu_shape() + 1, this->blob_top_->gpu_shape() + 1,
+          this->blob_kernel_shape_->gpu_data(), this->blob_pad_->gpu_data(),
+          this->blob_stride_->gpu_data(),
+          top_data_gpu + this->blob_top_->offset(n));
+      CUDA_POST_KERNEL_CHECK;
+    }
+
+    // Compare results against CPU version
+    for (int i = 0; i < this->blob_top_->count(); ++i) {
+      TypeParam cpuval = top_data_cpu[i];
+      TypeParam gpuval = this->blob_top_->cpu_data()[i];
+      EXPECT_EQ(cpuval, gpuval);
+      if (cpuval != gpuval) {
+        break;
+      }
+    }
+  }
+}
+
 }  // namespace caffe
index f50abe1..293aa26 100644 (file)
@@ -21,6 +21,7 @@ class Im2colLayerTest : public MultiDeviceTest<TypeParam> {
       : blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
         blob_top_(new Blob<Dtype>()) {
     // fill the values
+    Caffe::set_random_seed(1701);
     FillerParameter filler_param;
     GaussianFiller<Dtype> filler(filler_param);
     filler.Fill(this->blob_bottom_);
@@ -41,8 +42,8 @@ TYPED_TEST(Im2colLayerTest, TestSetup) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   Im2colLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   EXPECT_EQ(this->blob_top_->num(), 2);
@@ -56,8 +57,8 @@ TYPED_TEST(Im2colLayerTest, TestForward) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   Im2colLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
@@ -73,14 +74,27 @@ TYPED_TEST(Im2colLayerTest, TestGradient) {
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
-  convolution_param->set_kernel_size(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
   Im2colLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-2);
   checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
       this->blob_top_vec_);
 }
 
+TYPED_TEST(Im2colLayerTest, TestGradientForceND) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_force_nd_im2col(true);
+  Im2colLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-2);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
 
 TYPED_TEST(Im2colLayerTest, TestRect) {
   typedef typename TypeParam::Dtype Dtype;
@@ -89,7 +103,7 @@ TYPED_TEST(Im2colLayerTest, TestRect) {
       layer_param.mutable_convolution_param();
   convolution_param->set_kernel_h(5);
   convolution_param->set_kernel_w(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_stride(2);
   Im2colLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
@@ -108,7 +122,7 @@ TYPED_TEST(Im2colLayerTest, TestRectGradient) {
       layer_param.mutable_convolution_param();
   convolution_param->set_kernel_h(5);
   convolution_param->set_kernel_w(3);
-  convolution_param->set_stride(2);
+  convolution_param->add_stride(2);
   Im2colLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-2);
   checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
index c48f31f..b0a7be5 100644 (file)
@@ -1,6 +1,7 @@
 #include <cmath>
 #include <cstdlib>
 #include <cstring>
+#include <vector>
 
 #include "caffe/util/im2col.hpp"
 #include "caffe/util/math_functions.hpp"
@@ -45,6 +46,98 @@ template void im2col_cpu<double>(const double* data_im, const int channels,
     const int stride_w, double* data_col);
 
 template <typename Dtype>
+inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col,
+    const int num_spatial_axes, const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_output) {
+  if (!im2col) {
+    int im_size = im_shape[0];
+    for (int i = 0; i < num_spatial_axes; ++i) {
+      im_size *= im_shape[1 + i];
+    }
+    caffe_set(im_size, Dtype(0), data_output);
+  }
+  int kernel_size = 1;
+  for (int i = 0; i < num_spatial_axes; ++i) {
+    kernel_size *= kernel_shape[i];
+  }
+  const int channels_col = col_shape[0];
+  vector<int> d_offset(num_spatial_axes, 0);
+  vector<int> d_iter(num_spatial_axes, 0);
+  for (int c = 0; c < channels_col; ++c) {
+    // Loop over spatial axes in reverse order to compute a per-axis offset.
+    int offset = c;
+    for (int d_i = num_spatial_axes - 1; d_i >= 0; --d_i) {
+      if (d_i < num_spatial_axes - 1) {
+        offset /= kernel_shape[d_i + 1];
+      }
+      d_offset[d_i] = offset % kernel_shape[d_i];
+    }
+    for (bool incremented = true; incremented; ) {
+      // Loop over spatial axes in forward order to compute the indices in the
+      // image and column, and whether the index lies in the padding.
+      int index_col = c;
+      int index_im = c / kernel_size;
+      bool is_padding = false;
+      for (int d_i = 0; d_i < num_spatial_axes; ++d_i) {
+        const int d = d_iter[d_i];
+        const int d_pad = d * stride[d_i] - pad[d_i] + d_offset[d_i];
+        is_padding |= d_pad < 0 || d_pad >= im_shape[d_i + 1];
+        index_col *= col_shape[d_i + 1];
+        index_col += d;
+        index_im *= im_shape[d_i + 1];
+        index_im += d_pad;
+      }
+      if (im2col) {
+        if (is_padding) {
+          data_output[index_col] = 0;
+        } else {
+          data_output[index_col] = data_input[index_im];
+        }
+      } else if (!is_padding) {  // col2im
+        data_output[index_im] += data_input[index_col];
+      }
+      // Loop over spatial axes in reverse order to choose an index,
+      // like counting.
+      incremented = false;
+      for (int d_i = num_spatial_axes - 1; d_i >= 0; --d_i) {
+        const int d_max = col_shape[d_i + 1];
+        DCHECK_LT(d_iter[d_i], d_max);
+        if (d_iter[d_i] == d_max - 1) {
+          d_iter[d_i] = 0;
+        } else {  // d_iter[d_i] < d_max - 1
+          ++d_iter[d_i];
+          incremented = true;
+          break;
+        }
+      }
+    }  // while(incremented) {
+  }  // for (int c = 0; c < channels_col; ++c) {
+}
+
+template <typename Dtype>
+void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_col) {
+  const bool kIm2Col = true;
+  im2col_nd_core_cpu(data_im, kIm2Col, num_spatial_axes, im_shape, col_shape,
+                  kernel_shape, pad, stride, data_col);
+}
+
+// Explicit instantiation
+template void im2col_nd_cpu<float>(const float* data_im,
+    const int num_spatial_axes,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    float* data_col);
+template void im2col_nd_cpu<double>(const double* data_im,
+    const int num_spatial_axes,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    double* data_col);
+
+template <typename Dtype>
 void col2im_cpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int patch_h, const int patch_w,
     const int pad_h, const int pad_w,
@@ -80,4 +173,27 @@ template void col2im_cpu<double>(const double* data_col, const int channels,
     const int pad_h, const int pad_w, const int stride_h,
     const int stride_w, double* data_im);
 
+template <typename Dtype>
+void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_im) {
+  const bool kIm2Col = false;
+  im2col_nd_core_cpu(data_col, kIm2Col, num_spatial_axes, im_shape, col_shape,
+                     kernel_shape, pad, stride, data_im);
+}
+
+// Explicit instantiation
+template void col2im_nd_cpu<float>(const float* data_col,
+    const int num_spatial_axes,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    float* data_im);
+template void col2im_nd_cpu<double>(const double* data_col,
+    const int num_spatial_axes,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    double* data_im);
+
+
 }  // namespace caffe
index c90f93e..5a478ba 100644 (file)
@@ -59,7 +59,6 @@ void im2col_gpu(const Dtype* data_im, const int channels,
   CUDA_POST_KERNEL_CHECK;
 }
 
-
 // Explicit instantiation
 template void im2col_gpu<float>(const float* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
@@ -70,6 +69,156 @@ template void im2col_gpu<double>(const double* data_im, const int channels,
     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
     double* data_col);
 
+template <typename Dtype, int num_axes>
+__global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_col) {
+  int d_temp[num_axes];  // NOLINT(runtime/arrays)
+  int d_iter[num_axes];  // NOLINT(runtime/arrays)
+  int i;
+  CUDA_KERNEL_LOOP(index, n) {
+    // Initialize channel_in, computed in the loop below, with intermediate
+    // computations used to compute the spatial indices.
+    int channel_in = index;
+    int channel_out = 1;
+    for (i = num_axes - 1; i >= 0; --i) {
+      d_temp[i] = channel_in % col_shape[i + 1];
+      channel_in /= col_shape[i + 1];
+      channel_out *= kernel_shape[i];
+    }
+    channel_out *= channel_in;
+    int data_col_inc = 1;
+    for (i = 0; i < num_axes; ++i) {
+      channel_out *= col_shape[i + 1];
+      channel_out += d_temp[i];
+      d_temp[i] = d_temp[i] * stride[i] - pad[i];
+      channel_in *= im_shape[i + 1];
+      channel_in += d_temp[i];
+      data_col_inc *= col_shape[i + 1];
+      d_iter[i] = 0;
+    }
+    Dtype* data_col_ptr = data_col + channel_out;
+    const Dtype* data_im_ptr = data_im + channel_in;
+    bool incremented;
+    do {
+      bool in_range = true;
+      for (i = 0; i < num_axes; ++i) {
+        const int d_iter_im = d_iter[i] + d_temp[i];
+        in_range &= d_iter_im >= 0 && d_iter_im < im_shape[i + 1];
+        if (!in_range) { break; }
+      }
+      if (in_range) {
+        int data_im_offset = d_iter[0];
+        for (i = 1; i < num_axes; ++i) {
+          data_im_offset *= im_shape[i + 1];
+          data_im_offset += d_iter[i];
+        }
+        *data_col_ptr = data_im_ptr[data_im_offset];
+      } else {
+        *data_col_ptr = 0;
+      }
+      data_col_ptr += data_col_inc;
+      incremented = false;
+      for (i = num_axes - 1; i >= 0; --i) {
+        const int d_max = kernel_shape[i];
+        if (d_iter[i] == d_max - 1) {
+          d_iter[i] = 0;
+        } else {  // d_iter[i] < d_max - 1
+          ++d_iter[i];
+          incremented = true;
+          break;
+        }
+      }  // for (int i = num_axes - 1; i >= 0; --i)
+    } while (incremented);  // do
+  }  // CUDA_KERNEL_LOOP(index, n)
+}
+
+template <typename Dtype>
+void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes,
+    const int num_kernels, const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_col) {
+  switch (num_spatial_axes) {
+  case 1:
+    im2col_nd_gpu_kernel<Dtype, 1>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, im_shape, col_shape,
+        kernel_shape, pad, stride, data_col);
+    break;
+  case 2:
+    im2col_nd_gpu_kernel<Dtype, 2>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, im_shape, col_shape,
+        kernel_shape, pad, stride, data_col);
+    break;
+  case 3:
+    im2col_nd_gpu_kernel<Dtype, 3>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, im_shape, col_shape,
+        kernel_shape, pad, stride, data_col);
+    break;
+  case 4:
+    im2col_nd_gpu_kernel<Dtype, 4>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, im_shape, col_shape,
+        kernel_shape, pad, stride, data_col);
+    break;
+  case 5:
+    im2col_nd_gpu_kernel<Dtype, 5>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, im_shape, col_shape,
+        kernel_shape, pad, stride, data_col);
+    break;
+  case 6:
+    im2col_nd_gpu_kernel<Dtype, 6>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, im_shape, col_shape,
+        kernel_shape, pad, stride, data_col);
+    break;
+  case 7:
+    im2col_nd_gpu_kernel<Dtype, 7>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, im_shape, col_shape,
+        kernel_shape, pad, stride, data_col);
+    break;
+  case 8:
+    im2col_nd_gpu_kernel<Dtype, 8>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, im_shape, col_shape,
+        kernel_shape, pad, stride, data_col);
+    break;
+  case 9:
+    im2col_nd_gpu_kernel<Dtype, 9>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, im_shape, col_shape,
+        kernel_shape, pad, stride, data_col);
+    break;
+  case 10:
+    im2col_nd_gpu_kernel<Dtype, 10>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, im_shape, col_shape,
+        kernel_shape, pad, stride, data_col);
+    break;
+  default:
+    LOG(FATAL) << "im2col_nd_gpu does not support computation with "
+               << num_spatial_axes << " spatial axes";
+  }
+  CUDA_POST_KERNEL_CHECK;
+}
+
+// Explicit instantiation
+template void im2col_nd_gpu<float>(const float* data_im,
+    const int num_spatial_axes, const int col_size,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    float* data_col);
+template void im2col_nd_gpu<double>(const double* data_im,
+    const int num_spatial_axes, const int col_size,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    double* data_col);
+
 template <typename Dtype>
 __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
     const int height, const int width, const int channels,
@@ -141,4 +290,159 @@ template void col2im_gpu<double>(const double* data_col, const int channels,
     const int pad_h, const int pad_w, const int stride_h,
     const int stride_w, double* data_im);
 
+template <typename Dtype, int num_axes>
+__global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_im) {
+  int d_im[num_axes];  // NOLINT(runtime/arrays)
+  int d_col_iter[num_axes];  // NOLINT(runtime/arrays)
+  int d_col_start[num_axes];  // NOLINT(runtime/arrays)
+  int d_col_end[num_axes];  // NOLINT(runtime/arrays)
+  CUDA_KERNEL_LOOP(index, n) {
+    // Initialize channel_in, computed in the loop below, with intermediate
+    // computations used to compute the spatial indices.
+    int channel_im = index;
+    // Calculate d_im (image dimensions).
+    for (int i = num_axes - 1; i >= 0; --i) {
+      d_im[i] = channel_im % im_shape[i + 1] + pad[i];
+      channel_im /= im_shape[i + 1];
+    }
+    // Calculate col start/end indices.
+    bool done = false;
+    for (int i = 0; i < num_axes; ++i) {
+      d_col_start[i] = d_col_iter[i] =
+          (d_im[i] < kernel_shape[i]) ?
+          0 : (d_im[i] - kernel_shape[i]) / stride[i] + 1;
+      d_col_end[i] = min(d_im[i] / stride[i] + 1, col_shape[i + 1]);
+      if (d_col_start[i] >= d_col_end[i]) {
+        // Skip computation if the dimension is 0 at any spatial axis --
+        // final val will be 0.
+        data_im[index] = 0;
+        done = true;
+        break;  // for (int i = 0; i < num_axes; ++i)
+      }
+    }
+    if (done) {
+      continue;  // CUDA_KERNEL_LOOP(index, n)
+    }
+    // Loop over the col to compute the output val.
+    Dtype val = 0;
+    bool incremented = true;
+    do {
+      // Compute the final offset.
+      int final_offset = 0;
+      int kernel_shape_prod = 1;
+      for (int i = num_axes - 1; i >= 0; --i) {
+        final_offset +=
+            (d_im[i] - d_col_iter[i] * stride[i]) * kernel_shape_prod;
+        kernel_shape_prod *= kernel_shape[i];
+      }
+      final_offset += kernel_shape_prod * channel_im;
+      for (int i = 0; i < num_axes; ++i) {
+        final_offset *= col_shape[i + 1];
+        final_offset += d_col_iter[i];
+      }
+      val += data_col[final_offset];
+      incremented = false;
+      for (int i = num_axes - 1; i >= 0; --i) {
+        const int d_max = d_col_end[i];
+        if (d_col_iter[i] == d_max - 1) {
+          d_col_iter[i] = d_col_start[i];
+        } else {  // d_col_iter[i] < d_max - 1
+          ++d_col_iter[i];
+          incremented = true;
+          break;  // for (int i = num_axes - 1; i >= 0; --i)
+        }
+      }  // for (int i = num_axes - 1; i >= 0; --i)
+    }  while (incremented);
+    data_im[index] = val;
+  }  // CUDA_KERNEL_LOOP(index, n)
+}
+
+template <typename Dtype>
+void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes,
+    const int im_size, const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_im) {
+  switch (num_spatial_axes) {
+  case 1:
+    col2im_nd_gpu_kernel<Dtype, 1>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
+          im_size, data_col, im_shape, col_shape,
+          kernel_shape, pad, stride, data_im);
+    break;
+  case 2:
+    col2im_nd_gpu_kernel<Dtype, 2>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
+          im_size, data_col, im_shape, col_shape,
+          kernel_shape, pad, stride, data_im);
+    break;
+  case 3:
+    col2im_nd_gpu_kernel<Dtype, 3>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
+          im_size, data_col, im_shape, col_shape,
+          kernel_shape, pad, stride, data_im);
+    break;
+  case 4:
+    col2im_nd_gpu_kernel<Dtype, 4>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
+          im_size, data_col, im_shape, col_shape,
+          kernel_shape, pad, stride, data_im);
+    break;
+  case 5:
+    col2im_nd_gpu_kernel<Dtype, 5>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
+          im_size, data_col, im_shape, col_shape,
+          kernel_shape, pad, stride, data_im);
+    break;
+  case 6:
+    col2im_nd_gpu_kernel<Dtype, 6>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
+          im_size, data_col, im_shape, col_shape,
+          kernel_shape, pad, stride, data_im);
+    break;
+  case 7:
+    col2im_nd_gpu_kernel<Dtype, 7>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
+          im_size, data_col, im_shape, col_shape,
+          kernel_shape, pad, stride, data_im);
+    break;
+  case 8:
+    col2im_nd_gpu_kernel<Dtype, 8>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
+          im_size, data_col, im_shape, col_shape,
+          kernel_shape, pad, stride, data_im);
+    break;
+  case 9:
+    col2im_nd_gpu_kernel<Dtype, 9>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
+          im_size, data_col, im_shape, col_shape,
+          kernel_shape, pad, stride, data_im);
+    break;
+  case 10:
+    col2im_nd_gpu_kernel<Dtype, 10>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
+          im_size, data_col, im_shape, col_shape,
+          kernel_shape, pad, stride, data_im);
+    break;
+  default:
+    LOG(FATAL) << "col2im_nd_gpu does not support computation with "
+               << num_spatial_axes << " spatial axes";
+  }
+  CUDA_POST_KERNEL_CHECK;
+}
+
+// Explicit instantiation
+template void col2im_nd_gpu<float>(const float* data_col,
+    const int num_spatial_axes, const int im_size,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    float* data_im);
+template void col2im_nd_gpu<double>(const double* data_col,
+    const int num_spatial_axes, const int im_size,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    double* data_im);
+
 }  // namespace caffe
index 92e5cf5..ac379e5 100644 (file)
@@ -193,7 +193,7 @@ bool UpgradeV0LayerParameter(const V1LayerParameter& v0_layer_connection,
     }
     if (v0_layer_param.has_pad()) {
       if (type == "conv") {
-        layer_param->mutable_convolution_param()->set_pad(v0_layer_param.pad());
+        layer_param->mutable_convolution_param()->add_pad(v0_layer_param.pad());
       } else if (type == "pool") {
         layer_param->mutable_pooling_param()->set_pad(v0_layer_param.pad());
       } else {
@@ -203,7 +203,7 @@ bool UpgradeV0LayerParameter(const V1LayerParameter& v0_layer_connection,
     }
     if (v0_layer_param.has_kernelsize()) {
       if (type == "conv") {
-        layer_param->mutable_convolution_param()->set_kernel_size(
+        layer_param->mutable_convolution_param()->add_kernel_size(
             v0_layer_param.kernelsize());
       } else if (type == "pool") {
         layer_param->mutable_pooling_param()->set_kernel_size(
@@ -224,7 +224,7 @@ bool UpgradeV0LayerParameter(const V1LayerParameter& v0_layer_connection,
     }
     if (v0_layer_param.has_stride()) {
       if (type == "conv") {
-        layer_param->mutable_convolution_param()->set_stride(
+        layer_param->mutable_convolution_param()->add_stride(
             v0_layer_param.stride());
       } else if (type == "pool") {
         layer_param->mutable_pooling_param()->set_stride(