FlattenLayer gets a FlattenParameter with an axis, end_axis
authorJeff Donahue <jeff.donahue@gmail.com>
Thu, 1 Jan 2015 02:02:12 +0000 (18:02 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Wed, 3 Jun 2015 02:35:45 +0000 (19:35 -0700)
src/caffe/layers/flatten_layer.cpp
src/caffe/proto/caffe.proto
src/caffe/test/test_flatten_layer.cpp

index 745f271..f7e5c9c 100644 (file)
@@ -9,9 +9,19 @@ namespace caffe {
 template <typename Dtype>
 void FlattenLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-  vector<int> top_shape(2);
-  top_shape[0] = bottom[0]->num();
-  top_shape[1] = bottom[0]->count() / bottom[0]->num();
+  const int start_axis = bottom[0]->CanonicalAxisIndex(
+      this->layer_param_.flatten_param().axis());
+  const int end_axis = bottom[0]->CanonicalAxisIndex(
+      this->layer_param_.flatten_param().end_axis());
+  vector<int> top_shape;
+  for (int i = 0; i < start_axis; ++i) {
+    top_shape.push_back(bottom[0]->shape(i));
+  }
+  const int flattened_dim = bottom[0]->count(start_axis, end_axis + 1);
+  top_shape.push_back(flattened_dim);
+  for (int i = end_axis + 1; i < bottom[0]->num_axes(); ++i) {
+    top_shape.push_back(bottom[0]->shape(i));
+  }
   top[0]->Reshape(top_shape);
   CHECK_EQ(top[0]->count(), bottom[0]->count());
 }
index 619642f..f79cf80 100644 (file)
@@ -269,7 +269,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 135 (last added: log_param)
+// LayerParameter next available layer-specific ID: 136 (last added: flatten_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
@@ -326,6 +326,7 @@ message LayerParameter {
   optional DummyDataParameter dummy_data_param = 109;
   optional EltwiseParameter eltwise_param = 110;
   optional ExpParameter exp_param = 111;
+  optional FlattenParameter flatten_param = 135;
   optional HDF5DataParameter hdf5_data_param = 112;
   optional HDF5OutputParameter hdf5_output_param = 113;
   optional HingeLossParameter hinge_loss_param = 114;
@@ -533,6 +534,19 @@ message ExpParameter {
   optional float shift = 3 [default = 0.0];
 }
 
+/// Message that stores parameters used by FlattenLayer
+message FlattenParameter {
+  // The first axis to flatten: all preceding axes are retained in the output.
+  // May be negative to index from the end (e.g., -1 for the last axis).
+  optional int32 axis = 1 [default = 1];
+
+  // The last axis to flatten: all following axes are retained in the output.
+  // May be negative to index from the end (e.g., the default -1 for the last
+  // axis).
+  optional int32 end_axis = 2 [default = -1];
+}
+
+// Message that stores parameters used by HDF5DataLayer
 message HDF5DataParameter {
   // Specify the data source.
   optional string source = 1;
index 3042d29..7b6757c 100644 (file)
@@ -42,13 +42,48 @@ TYPED_TEST(FlattenLayerTest, TestSetup) {
   LayerParameter layer_param;
   FlattenLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
-  EXPECT_EQ(this->blob_top_->num(), 2);
-  EXPECT_EQ(this->blob_top_->channels(), 3 * 6 * 5);
-  EXPECT_EQ(this->blob_top_->height(), 1);
-  EXPECT_EQ(this->blob_top_->width(), 1);
+  ASSERT_EQ(this->blob_top_->num_axes(), 2);
+  EXPECT_EQ(this->blob_top_->shape(0), 2);
+  EXPECT_EQ(this->blob_top_->shape(1), 3 * 6 * 5);
 }
 
-TYPED_TEST(FlattenLayerTest, Test) {
+TYPED_TEST(FlattenLayerTest, TestSetupWithAxis) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_flatten_param()->set_axis(2);
+  FlattenLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(this->blob_top_->num_axes(), 3);
+  EXPECT_EQ(this->blob_top_->shape(0), 2);
+  EXPECT_EQ(this->blob_top_->shape(1), 3);
+  EXPECT_EQ(this->blob_top_->shape(2), 6 * 5);
+}
+
+TYPED_TEST(FlattenLayerTest, TestSetupWithEndAxis) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_flatten_param()->set_end_axis(-2);
+  FlattenLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(this->blob_top_->num_axes(), 3);
+  EXPECT_EQ(this->blob_top_->shape(0), 2);
+  EXPECT_EQ(this->blob_top_->shape(1), 3 * 6);
+  EXPECT_EQ(this->blob_top_->shape(2), 5);
+}
+
+TYPED_TEST(FlattenLayerTest, TestSetupWithStartAndEndAxis) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_flatten_param()->set_axis(0);
+  layer_param.mutable_flatten_param()->set_end_axis(-2);
+  FlattenLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(this->blob_top_->num_axes(), 2);
+  EXPECT_EQ(this->blob_top_->shape(0), 2 * 3 * 6);
+  EXPECT_EQ(this->blob_top_->shape(1), 5);
+}
+
+TYPED_TEST(FlattenLayerTest, TestForward) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
   FlattenLayer<Dtype> layer(layer_param);
@@ -71,5 +106,4 @@ TYPED_TEST(FlattenLayerTest, TestGradient) {
       this->blob_top_vec_);
 }
 
-
 }  // namespace caffe