SliceLayer: allow trivial operation with single top Blob
authorJeff Donahue <jeff.donahue@gmail.com>
Tue, 7 Oct 2014 18:55:54 +0000 (11:55 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Thu, 3 Sep 2015 09:03:27 +0000 (02:03 -0700)
include/caffe/common_layers.hpp
src/caffe/layers/slice_layer.cpp
src/caffe/layers/slice_layer.cu
src/caffe/test/test_slice_layer.cpp

index 8e64b3e..6d4a9e3 100644 (file)
@@ -625,7 +625,7 @@ class SliceLayer : public Layer<Dtype> {
 
   virtual inline const char* type() const { return "Slice"; }
   virtual inline int ExactNumBottomBlobs() const { return 1; }
-  virtual inline int MinTopBlobs() const { return 2; }
+  virtual inline int MinTopBlobs() const { return 1; }
 
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
index e4418c9..0a059ae 100644 (file)
@@ -67,11 +67,16 @@ void SliceLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
     }
   }
   CHECK_EQ(count, bottom[0]->count());
+  if (top.size() == 1) {
+    top[0]->ShareData(*bottom[0]);
+    top[0]->ShareDiff(*bottom[0]);
+  }
 }
 
 template <typename Dtype>
 void SliceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
+  if (top.size() == 1) { return; }
   int offset_slice_axis = 0;
   const Dtype* bottom_data = bottom[0]->cpu_data();
   const int bottom_slice_axis = bottom[0]->shape(slice_axis_);
@@ -92,7 +97,7 @@ void SliceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 template <typename Dtype>
 void SliceLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
-  if (!propagate_down[0]) { return; }
+  if (!propagate_down[0] || top.size() == 1) { return; }
   int offset_slice_axis = 0;
   Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
   const int bottom_slice_axis = bottom[0]->shape(slice_axis_);
index 796841d..e8dc6cd 100644 (file)
@@ -28,6 +28,7 @@ __global__ void Slice(const int nthreads, const Dtype* in_data,
 template <typename Dtype>
 void SliceLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
+  if (top.size() == 1) { return; }
   int offset_slice_axis = 0;
   const Dtype* bottom_data = bottom[0]->gpu_data();
   const int bottom_slice_axis = bottom[0]->shape(slice_axis_);
@@ -48,7 +49,7 @@ void SliceLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 template <typename Dtype>
 void SliceLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
-  if (!propagate_down[0]) { return; }
+  if (!propagate_down[0] || top.size() == 1) { return; }
   int offset_slice_axis = 0;
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
   const int bottom_slice_axis = bottom[0]->shape(slice_axis_);
index ccd0364..2d2d0fd 100644 (file)
@@ -88,6 +88,21 @@ TYPED_TEST(SliceLayerTest, TestSetupChannels) {
   EXPECT_EQ(this->blob_bottom_->width(), this->blob_top_0_->width());
 }
 
+TYPED_TEST(SliceLayerTest, TestTrivialSlice) {
+  // Test the trivial (single output) "slice" operation --
+  // should be the identity.
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  SliceLayer<Dtype> layer(layer_param);
+  this->blob_top_vec_0_.resize(1);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_0_);
+  ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_0_->shape());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_EQ(this->blob_bottom_->cpu_data()[i],
+              this->blob_top_0_->cpu_data()[i]);
+  }
+}
+
 TYPED_TEST(SliceLayerTest, TestSliceAcrossNum) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
@@ -161,6 +176,18 @@ TYPED_TEST(SliceLayerTest, TestSliceAcrossChannels) {
   }
 }
 
+TYPED_TEST(SliceLayerTest, TestGradientTrivial) {
+  // Test the trivial (single output) "slice" operation --
+  // should be the identity.
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  SliceLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  this->blob_top_vec_0_.resize(1);
+  checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_0_);
+}
+
 TYPED_TEST(SliceLayerTest, TestGradientAcrossNum) {
   typedef typename TypeParam::Dtype Dtype;
   // Gradient checks are slow; reduce blob size.