SliceLayer: generalized Blob axes
authorJeff Donahue <jeff.donahue@gmail.com>
Wed, 26 Nov 2014 11:22:59 +0000 (03:22 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Tue, 3 Mar 2015 23:55:14 +0000 (15:55 -0800)
include/caffe/common_layers.hpp
src/caffe/layers/slice_layer.cpp
src/caffe/layers/slice_layer.cu
src/caffe/proto/caffe.proto
src/caffe/test/test_slice_layer.cpp

index 114f24a..4e47e55 100644 (file)
@@ -450,11 +450,9 @@ class SliceLayer : public Layer<Dtype> {
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
   int count_;
-  int num_;
-  int channels_;
-  int height_;
-  int width_;
-  int slice_dim_;
+  int num_slices_;
+  int slice_size_;
+  int slice_axis_;
   vector<int> slice_point_;
 };
 
index 46c3acd..e4418c9 100644 (file)
@@ -11,9 +11,8 @@ template <typename Dtype>
 void SliceLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   const SliceParameter& slice_param = this->layer_param_.slice_param();
-  slice_dim_ = slice_param.slice_dim();
-  CHECK_GE(slice_dim_, 0);
-  CHECK_LE(slice_dim_, 1) << "Can only slice num and channels";
+  CHECK(!(slice_param.has_axis() && slice_param.has_slice_dim()))
+      << "Either axis or slice_dim should be specified; not both.";
   slice_point_.clear();
   std::copy(slice_param.slice_point().begin(),
       slice_param.slice_point().end(),
@@ -23,18 +22,27 @@ void SliceLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
 template <typename Dtype>
 void SliceLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-  count_ = 0;
-  num_ = bottom[0]->num();
-  channels_ = bottom[0]->channels();
-  height_ = bottom[0]->height();
-  width_ = bottom[0]->width();
+  const int num_axes = bottom[0]->num_axes();
+  const SliceParameter& slice_param = this->layer_param_.slice_param();
+  if (slice_param.has_slice_dim()) {
+    slice_axis_ = static_cast<int>(slice_param.slice_dim());
+    // Don't allow negative indexing for slice_dim, a uint32 -- almost
+    // certainly unintended.
+    CHECK_GE(slice_axis_, 0) << "casting slice_dim from uint32 to int32 "
+        << "produced negative result; slice_dim must satisfy "
+        << "0 <= slice_dim < " << kMaxBlobAxes;
+    CHECK_LT(slice_axis_, num_axes) << "slice_dim out of range.";
+  } else {
+    slice_axis_ = bottom[0]->CanonicalAxisIndex(slice_param.axis());
+  }
+  vector<int> top_shape = bottom[0]->shape();
+  const int bottom_slice_axis = bottom[0]->shape(slice_axis_);
+  num_slices_ = bottom[0]->count(0, slice_axis_);
+  slice_size_ = bottom[0]->count(slice_axis_ + 1);
+  int count = 0;
   if (slice_point_.size() != 0) {
     CHECK_EQ(slice_point_.size(), top.size() - 1);
-    if (slice_dim_ == 0) {
-      CHECK_LE(top.size(), num_);
-    } else {
-      CHECK_LE(top.size(), channels_);
-    }
+    CHECK_LE(top.size(), bottom_slice_axis);
     int prev = 0;
     vector<int> slices;
     for (int i = 0; i < slice_point_.size(); ++i) {
@@ -42,94 +50,64 @@ void SliceLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
       slices.push_back(slice_point_[i] - prev);
       prev = slice_point_[i];
     }
-    if (slice_dim_ == 0) {
-      slices.push_back(num_ - prev);
-      for (int i = 0; i < top.size(); ++i) {
-        top[i]->Reshape(slices[i], channels_, height_, width_);
-        count_ += top[i]->count();
-      }
-    } else {
-      slices.push_back(channels_ - prev);
-      for (int i = 0; i < top.size(); ++i) {
-        top[i]->Reshape(num_, slices[i], height_, width_);
-        count_ += top[i]->count();
-      }
+    slices.push_back(bottom_slice_axis - prev);
+    for (int i = 0; i < top.size(); ++i) {
+      top_shape[slice_axis_] = slices[i];
+      top[i]->Reshape(top_shape);
+      count += top[i]->count();
     }
   } else {
-    if (slice_dim_ == 0) {
-      CHECK_EQ(num_ % top.size(), 0)
-          << "Number of top blobs (" << top.size() << ") "
-          << "should evenly divide input num ( " << num_ << ")";
-      num_ = num_ / top.size();
-    } else {
-      CHECK_EQ(channels_ % top.size(), 0)
-          << "Number of top blobs (" << top.size() << ") "
-          << "should evenly divide input channels ( " << channels_ << ")";
-      channels_ = channels_ / top.size();
-    }
+    CHECK_EQ(bottom_slice_axis % top.size(), 0)
+        << "Number of top blobs (" << top.size() << ") should evenly "
+        << "divide input slice axis (" << bottom_slice_axis << ")";
+    top_shape[slice_axis_] = bottom_slice_axis / top.size();
     for (int i = 0; i < top.size(); ++i) {
-      top[i]->Reshape(num_, channels_, height_, width_);
-      count_ += top[i]->count();
+      top[i]->Reshape(top_shape);
+      count += top[i]->count();
     }
   }
-  CHECK_EQ(count_, bottom[0]->count());
+  CHECK_EQ(count, bottom[0]->count());
 }
 
 template <typename Dtype>
 void SliceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-  const Dtype* bottom_data = bottom[0]->mutable_cpu_data();
-  if (slice_dim_ == 0) {
-    int offset_num = 0;
-    for (int i = 0; i < top.size(); ++i) {
-      Blob<Dtype>* blob = top[i];
-      Dtype* top_data = blob->mutable_cpu_data();
-      caffe_copy(blob->count(), bottom_data + bottom[0]->offset(offset_num),
-                 top_data);
-      offset_num += blob->num();
+  int offset_slice_axis = 0;
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const int bottom_slice_axis = bottom[0]->shape(slice_axis_);
+  for (int i = 0; i < top.size(); ++i) {
+    Dtype* top_data = top[i]->mutable_cpu_data();
+    const int top_slice_axis = top[i]->shape(slice_axis_);
+    for (int n = 0; n < num_slices_; ++n) {
+      const int top_offset = n * top_slice_axis * slice_size_;
+      const int bottom_offset =
+          (n * bottom_slice_axis + offset_slice_axis) * slice_size_;
+      caffe_copy(top_slice_axis * slice_size_,
+          bottom_data + bottom_offset, top_data + top_offset);
     }
-  } else if (slice_dim_ == 1) {
-    int offset_channel = 0;
-    for (int i = 0; i < top.size(); ++i) {
-      Blob<Dtype>* blob = top[i];
-      Dtype* top_data = blob->mutable_cpu_data();
-      const int num_elem = blob->channels() * blob->height() * blob->width();
-      for (int n = 0; n < num_; ++n) {
-        caffe_copy(num_elem, bottom_data + bottom[0]->offset(n, offset_channel),
-                   top_data + blob->offset(n));
-      }
-      offset_channel += blob->channels();
-    }
-  }  // slice_dim_ is guaranteed to be 0 or 1 by SetUp.
+    offset_slice_axis += top_slice_axis;
+  }
 }
 
 template <typename Dtype>
 void SliceLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
   if (!propagate_down[0]) { return; }
+  int offset_slice_axis = 0;
   Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
-  if (slice_dim_ == 0) {
-    int offset_num = 0;
-    for (int i = 0; i < top.size(); ++i) {
-      Blob<Dtype>* blob = top[i];
-      const Dtype* top_diff = blob->cpu_diff();
-      caffe_copy(blob->count(), top_diff,
-                 bottom_diff + bottom[0]->offset(offset_num));
-      offset_num += blob->num();
+  const int bottom_slice_axis = bottom[0]->shape(slice_axis_);
+  for (int i = 0; i < top.size(); ++i) {
+    const Dtype* top_diff = top[i]->cpu_diff();
+    const int top_slice_axis = top[i]->shape(slice_axis_);
+    for (int n = 0; n < num_slices_; ++n) {
+      const int top_offset = n * top_slice_axis * slice_size_;
+      const int bottom_offset =
+          (n * bottom_slice_axis + offset_slice_axis) * slice_size_;
+      caffe_copy(top_slice_axis * slice_size_,
+          top_diff + top_offset, bottom_diff + bottom_offset);
     }
-  } else if (slice_dim_ == 1) {
-    int offset_channel = 0;
-    for (int i = 0; i < top.size(); ++i) {
-      Blob<Dtype>* blob = top[i];
-      const Dtype* top_diff = blob->cpu_diff();
-      const int num_elem = blob->channels() * blob->height() * blob->width();
-      for (int n = 0; n < num_; ++n) {
-        caffe_copy(num_elem, top_diff + blob->offset(n),
-                   bottom_diff + bottom[0]->offset(n, offset_channel));
-      }
-      offset_channel += blob->channels();
-    }
-  }  // slice_dim_ is guaranteed to be 0 or 1 by SetUp.
+    offset_slice_axis += top_slice_axis;
+  }
 }
 
 #ifdef CPU_ONLY
index b5c5e61..e6e6567 100644 (file)
@@ -9,58 +9,42 @@ namespace caffe {
 template <typename Dtype>
 void SliceLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-  const Dtype* bottom_data = bottom[0]->mutable_gpu_data();
-  if (slice_dim_ == 0) {
-    int offset_num = 0;
-    for (int i = 0; i < top.size(); ++i) {
-      Blob<Dtype>* blob = top[i];
-      Dtype* top_data = blob->mutable_gpu_data();
-      caffe_copy(blob->count(), bottom_data + bottom[0]->offset(offset_num),
-                 top_data);
-      offset_num += blob->num();
+  int offset_slice_axis = 0;
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  const int bottom_slice_axis = bottom[0]->shape(slice_axis_);
+  for (int i = 0; i < top.size(); ++i) {
+    Dtype* top_data = top[i]->mutable_gpu_data();
+    const int top_slice_axis = top[i]->shape(slice_axis_);
+    for (int n = 0; n < num_slices_; ++n) {
+      const int top_offset = n * top_slice_axis * slice_size_;
+      const int bottom_offset =
+          (n * bottom_slice_axis + offset_slice_axis) * slice_size_;
+      caffe_copy(top_slice_axis * slice_size_,
+          bottom_data + bottom_offset, top_data + top_offset);
     }
-  } else if (slice_dim_ == 1) {
-    int offset_channel = 0;
-    for (int i = 0; i < top.size(); ++i) {
-      Blob<Dtype>* blob = top[i];
-      Dtype* top_data = blob->mutable_gpu_data();
-      const int num_elem = blob->channels() * blob->height() * blob->width();
-      for (int n = 0; n < num_; ++n) {
-        caffe_copy(num_elem, bottom_data + bottom[0]->offset(n, offset_channel),
-                   top_data + blob->offset(n));
-      }
-      offset_channel += blob->channels();
-    }
-  }  // slice_dim_ is guaranteed to be 0 or 1 by SetUp.
+    offset_slice_axis += top_slice_axis;
+  }
 }
 
 template <typename Dtype>
 void SliceLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
   if (!propagate_down[0]) { return; }
+  int offset_slice_axis = 0;
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
-  if (slice_dim_ == 0) {
-    int offset_num = 0;
-    for (int i = 0; i < top.size(); ++i) {
-      Blob<Dtype>* blob = top[i];
-      const Dtype* top_diff = blob->gpu_diff();
-      caffe_copy(blob->count(), top_diff,
-                 bottom_diff + bottom[0]->offset(offset_num));
-      offset_num += blob->num();
-    }
-  } else if (slice_dim_ == 1) {
-    int offset_channel = 0;
-    for (int i = 0; i < top.size(); ++i) {
-      Blob<Dtype>* blob = top[i];
-      const Dtype* top_diff = blob->gpu_diff();
-      const int num_elem = blob->channels() * blob->height() * blob->width();
-      for (int n = 0; n < num_; ++n) {
-        caffe_copy(num_elem, top_diff + blob->offset(n),
-                   bottom_diff +  bottom[0]->offset(n, offset_channel));
-      }
-      offset_channel += blob->channels();
+  const int bottom_slice_axis = bottom[0]->shape(slice_axis_);
+  for (int i = 0; i < top.size(); ++i) {
+    const Dtype* top_diff = top[i]->gpu_diff();
+    const int top_slice_axis = top[i]->shape(slice_axis_);
+    for (int n = 0; n < num_slices_; ++n) {
+      const int top_offset = n * top_slice_axis * slice_size_;
+      const int bottom_offset =
+          (n * bottom_slice_axis + offset_slice_axis) * slice_size_;
+      caffe_copy(top_slice_axis * slice_size_,
+          top_diff + top_offset, bottom_diff + bottom_offset);
     }
-  }  // slice_dim_ is guaranteed to be 0 or 1 by SetUp.
+    offset_slice_axis += top_slice_axis;
+  }
 }
 
 INSTANTIATE_LAYER_GPU_FUNCS(SliceLayer);
index 7a4ecf9..7783a78 100644 (file)
@@ -674,12 +674,14 @@ message SigmoidParameter {
 
 // Message that stores parameters used by SliceLayer
 message SliceParameter {
-  // SliceLayer needs to know which dimension to slice across.
-  // Currently, SliceLayer only supports slicing across num (dim 0)
-  // and channels (dim 1).
-  // By default, SliceLayer slices across channels.
-  optional uint32 slice_dim = 1 [default = 1];
+  // The axis along which to slice -- may be negative to index from the end
+  // (e.g., -1 for the last axis).
+  // By default, SliceLayer concatenates blobs along the "channels" axis (1).
+  optional int32 axis = 3 [default = 1];
   repeated uint32 slice_point = 2;
+
+  // DEPRECATED: alias for "axis" -- does not support negative indexing.
+  optional uint32 slice_dim = 1 [default = 1];
 }
 
 // Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer
index 395be28..ccd0364 100644 (file)
@@ -62,7 +62,7 @@ TYPED_TEST_CASE(SliceLayerTest, TestDtypesAndDevices);
 TYPED_TEST(SliceLayerTest, TestSetupNum) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  layer_param.mutable_slice_param()->set_slice_dim(0);
+  layer_param.mutable_slice_param()->set_axis(0);
   SliceLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_1_);
   EXPECT_EQ(this->blob_bottom_->num(), 3 * this->blob_top_0_->num());
@@ -91,7 +91,7 @@ TYPED_TEST(SliceLayerTest, TestSetupChannels) {
 TYPED_TEST(SliceLayerTest, TestSliceAcrossNum) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  layer_param.mutable_slice_param()->set_slice_dim(0);
+  layer_param.mutable_slice_param()->set_axis(0);
   SliceLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_0_);
   const int top_num = this->blob_bottom_->num() / 2;
@@ -166,7 +166,7 @@ TYPED_TEST(SliceLayerTest, TestGradientAcrossNum) {
   // Gradient checks are slow; reduce blob size.
   this->ReduceBottomBlobSize();
   LayerParameter layer_param;
-  layer_param.mutable_slice_param()->set_slice_dim(0);
+  layer_param.mutable_slice_param()->set_axis(0);
   SliceLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-3);
   checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,