rebase and fix stuff, incorporate image and padding layers
authorJeff Donahue <jeff.donahue@gmail.com>
Tue, 18 Mar 2014 20:40:39 +0000 (13:40 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Fri, 28 Mar 2014 06:42:28 +0000 (23:42 -0700)
src/caffe/layer_factory.cpp
src/caffe/layers/images_layer.cpp
src/caffe/layers/padding_layer.cpp
src/caffe/proto/caffe.proto
src/caffe/test/test_images_layer.cpp
src/caffe/test/test_padding_layer.cpp

index 1fc79fc..542f716 100644 (file)
@@ -9,6 +9,7 @@
 #include "caffe/vision_layers.hpp"
 #include "caffe/proto/caffe.pb.h"
 
+using std::string;
 
 namespace caffe {
 
@@ -18,6 +19,7 @@ namespace caffe {
 // but we will leave it this way for now.
 template <typename Dtype>
 Layer<Dtype>* GetLayer(const LayerParameter& param) {
+  const string& name = param.name();
   const LayerParameter_LayerType& type = param.type();
   switch (type) {
   case LayerParameter_LayerType_ACCURACY:
@@ -40,6 +42,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new HDF5DataLayer<Dtype>(param);
   case LayerParameter_LayerType_HDF5_OUTPUT:
     return new HDF5OutputLayer<Dtype>(param);
+  case LayerParameter_LayerType_IMAGE_DATA:
+    return new ImagesLayer<Dtype>(param);
   case LayerParameter_LayerType_IM2COL:
     return new Im2colLayer<Dtype>(param);
   case LayerParameter_LayerType_INFOGAIN_LOSS:
@@ -50,6 +54,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new LRNLayer<Dtype>(param);
   case LayerParameter_LayerType_MULTINOMIAL_LOGISTIC_LOSS:
     return new MultinomialLogisticLossLayer<Dtype>(param);
+  case LayerParameter_LayerType_PADDING:
+    return new PaddingLayer<Dtype>(param);
   case LayerParameter_LayerType_POOLING:
     return new PoolingLayer<Dtype>(param);
   case LayerParameter_LayerType_RELU:
@@ -66,8 +72,10 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new TanHLayer<Dtype>(param);
   case LayerParameter_LayerType_WINDOW_DATA:
     return new WindowDataLayer<Dtype>(param);
+  case LayerParameter_LayerType_NONE:
+    LOG(FATAL) << "Layer " << name << " has unspecified type.";
   default:
-    LOG(FATAL) << "Unknown layer type: " << type;
+    LOG(FATAL) << "Layer " << name << " has unknown type " << type;
   }
   // just to suppress old compiler warnings.
   return (Layer<Dtype>*)(NULL);
index 5154f9a..63f79ca 100644 (file)
@@ -28,15 +28,16 @@ void* ImagesLayerPrefetch(void* layer_pointer) {
   CHECK(layer->prefetch_data_);
   Dtype* top_data = layer->prefetch_data_->mutable_cpu_data();
   Dtype* top_label = layer->prefetch_label_->mutable_cpu_data();
-  const Dtype scale = layer->layer_param_.scale();
-  const int batchsize = layer->layer_param_.batchsize();
-  const int cropsize = layer->layer_param_.cropsize();
-  const bool mirror = layer->layer_param_.mirror();
-  const int new_height  = layer->layer_param_.new_height();
-  const int new_width  = layer->layer_param_.new_height();
-
-  if (mirror && cropsize == 0) {
-    LOG(FATAL) << "Current implementation requires mirror and cropsize to be "
+  ImageDataParameter image_data_param = layer->layer_param_.image_data_param();
+  const Dtype scale = image_data_param.scale();
+  const int batch_size = image_data_param.batch_size();
+  const int crop_size = image_data_param.crop_size();
+  const bool mirror = image_data_param.mirror();
+  const int new_height = image_data_param.new_height();
+  const int new_width = image_data_param.new_width();
+
+  if (mirror && crop_size == 0) {
+    LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
         << "set at the same time.";
   }
   // datum scales
@@ -46,7 +47,7 @@ void* ImagesLayerPrefetch(void* layer_pointer) {
   const int size = layer->datum_size_;
   const int lines_size = layer->lines_.size();
   const Dtype* mean = layer->data_mean_.cpu_data();
-  for (int itemid = 0; itemid < batchsize; ++itemid) {
+  for (int item_id = 0; item_id < batch_size; ++item_id) {
     // get a blob
     CHECK_GT(lines_size, layer->lines_id_);
     if (!ReadImageToDatum(layer->lines_[layer->lines_id_].first,
@@ -55,27 +56,27 @@ void* ImagesLayerPrefetch(void* layer_pointer) {
       continue;
     }
     const string& data = datum.data();
-    if (cropsize) {
+    if (crop_size) {
       CHECK(data.size()) << "Image cropping only support uint8 data";
       int h_off, w_off;
       // We only do random crop when we do training.
       if (Caffe::phase() == Caffe::TRAIN) {
         // NOLINT_NEXT_LINE(runtime/threadsafe_fn)
-        h_off = rand() % (height - cropsize);
+        h_off = rand() % (height - crop_size);
         // NOLINT_NEXT_LINE(runtime/threadsafe_fn)
-        w_off = rand() % (width - cropsize);
+        w_off = rand() % (width - crop_size);
       } else {
-        h_off = (height - cropsize) / 2;
-        w_off = (width - cropsize) / 2;
+        h_off = (height - crop_size) / 2;
+        w_off = (width - crop_size) / 2;
       }
       // NOLINT_NEXT_LINE(runtime/threadsafe_fn)
       if (mirror && rand() % 2) {
         // Copy mirrored version
         for (int c = 0; c < channels; ++c) {
-          for (int h = 0; h < cropsize; ++h) {
-            for (int w = 0; w < cropsize; ++w) {
-              top_data[((itemid * channels + c) * cropsize + h) * cropsize
-                       + cropsize - 1 - w] =
+          for (int h = 0; h < crop_size; ++h) {
+            for (int w = 0; w < crop_size; ++w) {
+              top_data[((item_id * channels + c) * crop_size + h) * crop_size
+                       + crop_size - 1 - w] =
                   (static_cast<Dtype>(
                       (uint8_t)data[(c * height + h + h_off) * width
                                     + w + w_off])
@@ -87,9 +88,10 @@ void* ImagesLayerPrefetch(void* layer_pointer) {
       } else {
         // Normal copy
         for (int c = 0; c < channels; ++c) {
-          for (int h = 0; h < cropsize; ++h) {
-            for (int w = 0; w < cropsize; ++w) {
-              top_data[((itemid * channels + c) * cropsize + h) * cropsize + w]
+          for (int h = 0; h < crop_size; ++h) {
+            for (int w = 0; w < crop_size; ++w) {
+              top_data[((item_id * channels + c) * crop_size + h)
+                       * crop_size + w]
                   = (static_cast<Dtype>(
                       (uint8_t)data[(c * height + h + h_off) * width
                                     + w + w_off])
@@ -103,25 +105,25 @@ void* ImagesLayerPrefetch(void* layer_pointer) {
       // Just copy the whole data
       if (data.size()) {
         for (int j = 0; j < size; ++j) {
-          top_data[itemid * size + j] =
+          top_data[item_id * size + j] =
               (static_cast<Dtype>((uint8_t)data[j]) - mean[j]) * scale;
         }
       } else {
         for (int j = 0; j < size; ++j) {
-          top_data[itemid * size + j] =
+          top_data[item_id * size + j] =
               (datum.float_data(j) - mean[j]) * scale;
         }
       }
     }
 
-    top_label[itemid] = datum.label();
+    top_label[item_id] = datum.label();
     // go to the next iter
     layer->lines_id_++;
     if (layer->lines_id_ >= lines_size) {
       // We have reached the end. Restart from the first.
       DLOG(INFO) << "Restarting data prefetching from start.";
       layer->lines_id_ = 0;
-      if (layer->layer_param_.shuffle_images()) {
+      if (layer->layer_param_.image_data_param().shuffle()) {
         std::random_shuffle(layer->lines_.begin(), layer->lines_.end());
       }
     }
@@ -141,22 +143,22 @@ void ImagesLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 0) << "Input Layer takes no input blobs.";
   CHECK_EQ(top->size(), 2) << "Input Layer takes two blobs as output.";
-  const int new_height = this->layer_param_.new_height();
-  const int new_width = this->layer_param_.new_height();
+  const int new_height  = this->layer_param_.image_data_param().new_height();
+  const int new_width  = this->layer_param_.image_data_param().new_height();
   CHECK((new_height == 0 && new_width == 0) ||
-      (new_height > 0 && new_width > 0)) <<
-  "Current implementation requires new_height and new_width to be set"
-  "at the same time.";
+      (new_height > 0 && new_width > 0)) << "Current implementation requires "
+      "new_height and new_width to be set at the same time.";
   // Read the file with filenames and labels
-  LOG(INFO) << "Opening file " << this->layer_param_.source();
-  std::ifstream infile(this->layer_param_.source().c_str());
+  const string& source = this->layer_param_.image_data_param().source();
+  LOG(INFO) << "Opening file " << source;
+  std::ifstream infile(source.c_str());
   string filename;
   int label;
   while (infile >> filename >> label) {
     lines_.push_back(std::make_pair(filename, label));
   }
 
-  if (this->layer_param_.shuffle_images()) {
+  if (this->layer_param_.image_data_param().shuffle()) {
     // randomly shuffle data
     LOG(INFO) << "Shuffling data";
     std::random_shuffle(lines_.begin(), lines_.end());
@@ -165,9 +167,10 @@ void ImagesLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
 
   lines_id_ = 0;
   // Check if we would need to randomly skip a few data points
-  if (this->layer_param_.rand_skip()) {
+  if (this->layer_param_.image_data_param().rand_skip()) {
     // NOLINT_NEXT_LINE(runtime/threadsafe_fn)
-    unsigned int skip = rand() % this->layer_param_.rand_skip();
+    unsigned int skip = rand() %
+        this->layer_param_.image_data_param().rand_skip();
     LOG(INFO) << "Skipping first " << skip << " data points.";
     CHECK_GT(lines_.size(), skip) << "Not enought points to skip";
     lines_id_ = skip;
@@ -177,39 +180,37 @@ void ImagesLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK(ReadImageToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                          new_height, new_width, &datum));
   // image
-  int cropsize = this->layer_param_.cropsize();
-  if (cropsize > 0) {
-    (*top)[0]->Reshape(
-        this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize);
-    prefetch_data_.reset(new Blob<Dtype>(
-        this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize));
+  const int crop_size = this->layer_param_.image_data_param().crop_size();
+  const int batch_size = this->layer_param_.image_data_param().batch_size();
+  const string& mean_file = this->layer_param_.image_data_param().mean_file();
+  if (crop_size > 0) {
+    (*top)[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size);
+    prefetch_data_.reset(new Blob<Dtype>(batch_size, datum.channels(),
+                                         crop_size, crop_size));
   } else {
-    (*top)[0]->Reshape(
-        this->layer_param_.batchsize(), datum.channels(), datum.height(),
-        datum.width());
-    prefetch_data_.reset(new Blob<Dtype>(
-        this->layer_param_.batchsize(), datum.channels(), datum.height(),
-        datum.width()));
+    (*top)[0]->Reshape(batch_size, datum.channels(), datum.height(),
+                       datum.width());
+    prefetch_data_.reset(new Blob<Dtype>(batch_size, datum.channels(),
+                                         datum.height(), datum.width()));
   }
   LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
       << (*top)[0]->channels() << "," << (*top)[0]->height() << ","
       << (*top)[0]->width();
   // label
-  (*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1);
-  prefetch_label_.reset(
-      new Blob<Dtype>(this->layer_param_.batchsize(), 1, 1, 1));
+  (*top)[1]->Reshape(batch_size, 1, 1, 1);
+  prefetch_label_.reset(new Blob<Dtype>(batch_size, 1, 1, 1));
   // datum size
   datum_channels_ = datum.channels();
   datum_height_ = datum.height();
   datum_width_ = datum.width();
   datum_size_ = datum.channels() * datum.height() * datum.width();
-  CHECK_GT(datum_height_, cropsize);
-  CHECK_GT(datum_width_, cropsize);
+  CHECK_GT(datum_height_, crop_size);
+  CHECK_GT(datum_width_, crop_size);
   // check if we want to have mean
-  if (this->layer_param_.has_meanfile()) {
+  if (this->layer_param_.image_data_param().has_mean_file()) {
     BlobProto blob_proto;
-    LOG(INFO) << "Loading mean file from" << this->layer_param_.meanfile();
-    ReadProtoFromBinaryFile(this->layer_param_.meanfile().c_str(), &blob_proto);
+    LOG(INFO) << "Loading mean file from" << mean_file;
+    ReadProtoFromBinaryFile(mean_file.c_str(), &blob_proto);
     data_mean_.FromProto(blob_proto);
     CHECK_EQ(data_mean_.num(), 1);
     CHECK_EQ(data_mean_.channels(), datum_channels_);
index 6b22638..61fc58c 100644 (file)
@@ -16,7 +16,7 @@ void PaddingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
                   "convolutions and WILL BE REMOVED. Please update your model "
                   "prototxt to replace padding layers with pad fields. "
                   "See https://github.com/BVLC/caffe/pull/128.";
-  PAD_ = this->layer_param_.pad();
+  PAD_ = this->layer_param_.padding_param().pad();
   CHECK_EQ(bottom.size(), 1) << "Padding Layer takes a single blob as input.";
   CHECK_EQ(top->size(), 1) << "Padding Layer takes a single blob as output.";
   NUM_ = bottom[0]->num();
index b69298b..90113c9 100644 (file)
@@ -92,32 +92,41 @@ message SolverState {
 message LayerParameter {
   optional string name = 1; // the layer name
 
-  // Add new LayerTypes to the enum below in lexicographical order, starting
-  // with the next available ID in the comment line above the enum.
-  // Update the next available ID when you add a new LayerType.
-  // LayerType next available ID: 21
+  // Add new LayerTypes to the enum below in lexicographical order (other than
+  // starting with NONE), starting with the next available ID in the comment
+  // line above the enum. Update the next available ID when you add a new
+  // LayerType.
+  //
+  // LayerType next available ID: 25
   enum LayerType {
-    ACCURACY = 0;
-    BNLL = 1;
-    CONCAT = 2;
-    CONVOLUTION = 3;
-    DATA = 4;
-    DROPOUT = 5;
-    EUCLIDEAN_LOSS = 6;
-    FLATTEN = 7;
-    HDF5_DATA = 8;
-    IM2COL = 9;
-    INFOGAIN_LOSS = 10;
-    INNER_PRODUCT = 11;
-    LRN = 12;
-    MULTINOMIAL_LOGISTIC_LOSS = 13;
-    POOLING = 14;
-    RELU = 15;
-    SIGMOID = 16;
-    SOFTMAX = 17;
-    SOFTMAX_LOSS = 18;
-    SPLIT = 19;
-    TANH = 20;
+    // "NONE" layer type is 0th enum element so that we don't cause confusion
+    // by defaulting to an existent LayerType (instead, should usually error if
+    // the type is unspecified).
+    NONE = 0;
+    ACCURACY = 1;
+    BNLL = 2;
+    CONCAT = 3;
+    CONVOLUTION = 4;
+    DATA = 5;
+    DROPOUT = 6;
+    EUCLIDEAN_LOSS = 7;
+    FLATTEN = 8;
+    HDF5_DATA = 9;
+    IM2COL = 10;
+    IMAGE_DATA = 11;
+    INFOGAIN_LOSS = 12;
+    INNER_PRODUCT = 13;
+    LRN = 14;
+    MULTINOMIAL_LOGISTIC_LOSS = 15;
+    PADDING = 16;
+    POOLING = 17;
+    RELU = 18;
+    SIGMOID = 19;
+    SOFTMAX = 20;
+    SOFTMAX_LOSS = 21;
+    SPLIT = 22;
+    TANH = 23;
+    WINDOW_DATA = 24;
   }
   optional LayerType type = 2; // the layer type from the enum above
 
@@ -138,10 +147,12 @@ message LayerParameter {
   optional DataParameter data_param = 10;
   optional DropoutParameter dropout_param = 11;
   optional HDF5DataParameter hdf5_data_param = 12;
-  optional InfogainLossParameter infogain_loss_param = 13;
-  optional InnerProductParameter inner_product_param = 14;
-  optional LRNParameter lrn_param = 15;
-  optional PoolingParameter pooling_param = 16;
+  optional ImageDataParameter image_data_param = 13;
+  optional InfogainLossParameter infogain_loss_param = 14;
+  optional InnerProductParameter inner_product_param = 15;
+  optional LRNParameter lrn_param = 16;
+  optional PaddingParameter padding_param = 17;
+  optional PoolingParameter pooling_param = 18;
 }
 
 // Message that stores parameters used by ConcatLayer
@@ -199,7 +210,34 @@ message HDF5DataParameter {
   optional uint32 batch_size = 2;
 }
 
-// Message that stores parameters used by InfogainLossLayer
+// Message that stores parameters used by ImageDataLayer
+message ImageDataParameter {
+  // Specify the data source.
+  optional string source = 1;
+  // For data pre-processing, we can do simple scaling and subtracting the
+  // data mean, if provided. Note that the mean subtraction is always carried
+  // out before scaling.
+  optional float scale = 2 [default = 1];
+  optional string mean_file = 3;
+  // Specify the batch size.
+  optional uint32 batch_size = 4;
+  // Specify if we would like to randomly crop an image.
+  optional uint32 crop_size = 5 [default = 0];
+  // Specify if we want to randomly mirror data.
+  optional bool mirror = 6 [default = false];
+  // The rand_skip variable is for the data layer to skip a few data points
+  // to avoid all asynchronous sgd clients to start at the same point. The skip
+  // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
+  // be larger than the number of keys in the leveldb.
+  optional uint32 rand_skip = 7 [default = 0];
+  // Whether or not ImageLayer should shuffle the list of files at every epoch.
+  optional bool shuffle = 8 [default = false];
+  // It will also resize images if new_height or new_width are not zero.
+  optional uint32 new_height = 9 [default = 0];
+  optional uint32 new_width = 10 [default = 0];
+}
+
+// Message that stores parameters InfogainLossLayer
 message InfogainLossParameter {
   // Specify the infogain matrix source.
   optional string source = 1;
@@ -220,6 +258,11 @@ message LRNParameter {
   optional float beta = 3 [default = 0.75]; // for local response norm
 }
 
+// Message that stores parameters used by PaddingLayer
+message PaddingParameter {
+  optional uint32 pad = 1 [default = 0]; // The padding size
+}
+
 // Message that stores parameters used by PoolingLayer
 message PoolingParameter {
   enum PoolMethod {
@@ -232,18 +275,6 @@ message PoolingParameter {
   optional uint32 stride = 3 [default = 1]; // The stride
 }
 
-// Message that stores parameters used by DropoutLayer
-message DropoutParameter {
-  optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio
-}
-
-// Message that stores parameters used by LRNLayer
-message LRNParameter {
-  optional uint32 local_size = 1 [default = 5]; // for local response norm
-  optional float alpha = 2 [default = 1.]; // for local response norm
-  optional float beta = 3 [default = 0.75]; // for local response norm
-}
-
 message WindowDataParameter {
   // Fields related to detection (det_*)
   // foreground (object) overlap threshold
@@ -268,33 +299,3 @@ message WindowDataParameter {
 message HDF5OutputParameter {
   optional string file_name = 1;
 }
-
-// Message that stores parameters used by ConcatLayer
-message ConcatParameter {
-  // For ConcatLayer, one needs to specify the dimension for concatenation, and
-  // the other dimensions must be the same for all the bottom blobs.
-  // By default it will concatenate blobs along the channels dimension.
-  optional uint32 concat_dim = 1 [default = 1];
-}
-
-// Message that stores parameters used by ReshapeLayer
-message ReshapeParameter {
-  // For ReshapeLayer, one needs to specify the new dimensions.
-  optional int32 new_num = 1 [default = 0];
-  optional int32 new_channels = 2 [default = 0];
-  optional int32 new_height = 3 [default = 0];
-  optional int32 new_width = 4 [default = 0];
-}
-
-// Message that stores parameters used by ImageDataLayer
-message ImageDataParameter {
-  // Whether or not ImageLayer should shuffle the list of files at every epoch.
-  // It will also resize images if new_height or new_width are not zero.
-  optional bool shuffle_images = 64 [default = false];
-}
-
-// Message that stores parameters InfogainLossLayer
-message InfogainLossParameter {
-  // Specify the infogain matrix source.
-  optional string source = 1;
-}
index e8ed7c1..0cd1001 100644 (file)
@@ -55,9 +55,10 @@ TYPED_TEST_CASE(ImagesLayerTest, Dtypes);
 
 TYPED_TEST(ImagesLayerTest, TestRead) {
   LayerParameter param;
-  param.set_batchsize(5);
-  param.set_source(this->filename);
-  param.set_shuffle_images(false);
+  ImageDataParameter* image_data_param = param.mutable_image_data_param();
+  image_data_param->set_batch_size(5);
+  image_data_param->set_source(this->filename);
+  image_data_param->set_shuffle(false);
   ImagesLayer<TypeParam> layer(param);
   layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
   EXPECT_EQ(this->blob_top_data_->num(), 5);
@@ -79,11 +80,12 @@ TYPED_TEST(ImagesLayerTest, TestRead) {
 
 TYPED_TEST(ImagesLayerTest, TestResize) {
   LayerParameter param;
-  param.set_batchsize(5);
-  param.set_source(this->filename);
-  param.set_new_height(256);
-  param.set_new_width(256);
-  param.set_shuffle_images(false);
+  ImageDataParameter* image_data_param = param.mutable_image_data_param();
+  image_data_param->set_batch_size(5);
+  image_data_param->set_source(this->filename);
+  image_data_param->set_new_height(256);
+  image_data_param->set_new_width(256);
+  image_data_param->set_shuffle(false);
   ImagesLayer<TypeParam> layer(param);
   layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
   EXPECT_EQ(this->blob_top_data_->num(), 5);
@@ -105,9 +107,10 @@ TYPED_TEST(ImagesLayerTest, TestResize) {
 
 TYPED_TEST(ImagesLayerTest, TestShuffle) {
   LayerParameter param;
-  param.set_batchsize(5);
-  param.set_source(this->filename);
-  param.set_shuffle_images(true);
+  ImageDataParameter* image_data_param = param.mutable_image_data_param();
+  image_data_param->set_batch_size(5);
+  image_data_param->set_source(this->filename);
+  image_data_param->set_shuffle(true);
   ImagesLayer<TypeParam> layer(param);
   layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
   EXPECT_EQ(this->blob_top_data_->num(), 5);
index c775f3b..59b012c 100644 (file)
@@ -42,7 +42,8 @@ TYPED_TEST_CASE(PaddingLayerTest, Dtypes);
 
 TYPED_TEST(PaddingLayerTest, TestCPU) {
   LayerParameter layer_param;
-  layer_param.set_pad(1);
+  PaddingParameter* padding_param = layer_param.mutable_padding_param();
+  padding_param->set_pad(1);
   Caffe::set_mode(Caffe::CPU);
   PaddingLayer<TypeParam> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
@@ -65,7 +66,8 @@ TYPED_TEST(PaddingLayerTest, TestCPU) {
 
 TYPED_TEST(PaddingLayerTest, TestCPUGrad) {
   LayerParameter layer_param;
-  layer_param.set_pad(1);
+  PaddingParameter* padding_param = layer_param.mutable_padding_param();
+  padding_param->set_pad(1);
   Caffe::set_mode(Caffe::CPU);
   PaddingLayer<TypeParam> layer(layer_param);
   GradientChecker<TypeParam> checker(1e-2, 1e-3);
@@ -76,7 +78,8 @@ TYPED_TEST(PaddingLayerTest, TestCPUGrad) {
 TYPED_TEST(PaddingLayerTest, TestGPU) {
   if (CAFFE_TEST_CUDA_PROP.major >= 2) {
     LayerParameter layer_param;
-    layer_param.set_pad(1);
+    PaddingParameter* padding_param = layer_param.mutable_padding_param();
+    padding_param->set_pad(1);
     Caffe::set_mode(Caffe::GPU);
     PaddingLayer<TypeParam> layer(layer_param);
     layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
@@ -103,7 +106,8 @@ TYPED_TEST(PaddingLayerTest, TestGPU) {
 TYPED_TEST(PaddingLayerTest, TestGPUGrad) {
   if (CAFFE_TEST_CUDA_PROP.major >= 2) {
     LayerParameter layer_param;
-    layer_param.set_pad(1);
+    PaddingParameter* padding_param = layer_param.mutable_padding_param();
+    padding_param->set_pad(1);
     Caffe::set_mode(Caffe::GPU);
     PaddingLayer<TypeParam> layer(layer_param);
     GradientChecker<TypeParam> checker(1e-2, 1e-3);