remove padding layer
authorJeff Donahue <jeff.donahue@gmail.com>
Fri, 21 Mar 2014 21:33:42 +0000 (14:33 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Fri, 28 Mar 2014 06:42:28 +0000 (23:42 -0700)
include/caffe/vision_layers.hpp
src/caffe/layer_factory.cpp
src/caffe/layers/padding_layer.cpp [deleted file]
src/caffe/layers/padding_layer.cu [deleted file]
src/caffe/proto/caffe.proto
src/caffe/test/test_padding_layer.cpp [deleted file]

index 5807938..110b272 100644 (file)
@@ -203,33 +203,6 @@ class InnerProductLayer : public Layer<Dtype> {
 
 
 template <typename Dtype>
-class PaddingLayer : public Layer<Dtype> {
- public:
-  explicit PaddingLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  unsigned int PAD_;
-  int NUM_;
-  int CHANNEL_;
-  int HEIGHT_IN_;
-  int WIDTH_IN_;
-  int HEIGHT_OUT_;
-  int WIDTH_OUT_;
-};
-
-
-template <typename Dtype>
 class LRNLayer : public Layer<Dtype> {
  public:
   explicit LRNLayer(const LayerParameter& param)
index 542f716..f86e12e 100644 (file)
@@ -54,8 +54,6 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new LRNLayer<Dtype>(param);
   case LayerParameter_LayerType_MULTINOMIAL_LOGISTIC_LOSS:
     return new MultinomialLogisticLossLayer<Dtype>(param);
-  case LayerParameter_LayerType_PADDING:
-    return new PaddingLayer<Dtype>(param);
   case LayerParameter_LayerType_POOLING:
     return new PoolingLayer<Dtype>(param);
   case LayerParameter_LayerType_RELU:
diff --git a/src/caffe/layers/padding_layer.cpp b/src/caffe/layers/padding_layer.cpp
deleted file mode 100644 (file)
index 61fc58c..0000000
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright 2014 BVLC and contributors.
-
-#include <iostream>  // NOLINT(readability/streams)
-#include <vector>
-
-#include "caffe/layer.hpp"
-#include "caffe/vision_layers.hpp"
-
-namespace caffe {
-
-template <typename Dtype>
-void PaddingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
-  // DEPRECATION
-  LOG(WARNING) << "Padding layers are deprecated in favor of padding-aware "
-                  "convolutions and WILL BE REMOVED. Please update your model "
-                  "prototxt to replace padding layers with pad fields. "
-                  "See https://github.com/BVLC/caffe/pull/128.";
-  PAD_ = this->layer_param_.padding_param().pad();
-  CHECK_EQ(bottom.size(), 1) << "Padding Layer takes a single blob as input.";
-  CHECK_EQ(top->size(), 1) << "Padding Layer takes a single blob as output.";
-  NUM_ = bottom[0]->num();
-  CHANNEL_ = bottom[0]->channels();
-  HEIGHT_IN_ = bottom[0]->height();
-  WIDTH_IN_ = bottom[0]->width();
-  HEIGHT_OUT_ = HEIGHT_IN_ + PAD_ * 2;
-  WIDTH_OUT_ = WIDTH_IN_ + PAD_ * 2;
-  (*top)[0]->Reshape(NUM_, CHANNEL_, HEIGHT_OUT_, WIDTH_OUT_);
-}
-
-template <typename Dtype>
-Dtype PaddingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
-  Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  const Dtype* bottom_data = bottom[0]->cpu_data();
-  memset(top_data, 0, sizeof(Dtype) * (*top)[0]->count());
-  // In short, top[n, c, h, w] = bottom[n, c, h-pad, w-pad] if in range
-  for (int n = 0; n < NUM_; ++n) {
-    for (int c = 0; c < CHANNEL_; ++c) {
-      for (int h = 0; h < HEIGHT_IN_; ++h) {
-        // copy the width part
-        memcpy(
-            top_data + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
-                * WIDTH_OUT_ + PAD_,
-            bottom_data + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
-            sizeof(Dtype) * WIDTH_IN_);
-      }
-    }
-  }
-  return Dtype(0.);
-}
-
-template <typename Dtype>
-void PaddingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  const Dtype* top_diff = top[0]->cpu_diff();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  for (int n = 0; n < NUM_; ++n) {
-    for (int c = 0; c < CHANNEL_; ++c) {
-      for (int h = 0; h < HEIGHT_IN_; ++h) {
-        // copy the width part
-        memcpy(
-            bottom_diff + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
-            top_diff + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
-                * WIDTH_OUT_ + PAD_,
-            sizeof(Dtype) * WIDTH_IN_);
-      }
-    }
-  }
-}
-
-INSTANTIATE_CLASS(PaddingLayer);
-
-}  // namespace caffe
diff --git a/src/caffe/layers/padding_layer.cu b/src/caffe/layers/padding_layer.cu
deleted file mode 100644 (file)
index 8023fef..0000000
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright 2014 BVLC and contributors.
-
-#include <iostream>  // NOLINT(readability/streams)
-#include <vector>
-
-#include "caffe/layer.hpp"
-#include "caffe/vision_layers.hpp"
-
-namespace caffe {
-
-template <typename Dtype>
-__global__ void PaddingForward(const int count, const Dtype* in, Dtype* out,
-    const int num, const int channel, const int height_in, const int width_in,
-    const int pad) {
-  CUDA_KERNEL_LOOP(index, count) {
-    int height_out = height_in + pad + pad;
-    int width_out = width_in + pad + pad;
-    int w = index % width_in;
-    index /= width_in;
-    int h = index % height_in;
-    index /= height_in;
-    int c = index % channel;
-    index /= channel;
-    out[((index * channel + c) * height_out + h + pad) * width_out + pad + w] =
-        in[((index * channel + c) * height_in + h) * width_in + w];
-  }
-}
-
-template <typename Dtype>
-Dtype PaddingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  const Dtype* bottom_data = bottom[0]->gpu_data();
-  Dtype* top_data = (*top)[0]->mutable_gpu_data();
-  const int count = bottom[0]->count();
-  // First, set all data to be zero for the boundary pixels
-  CUDA_CHECK(cudaMemset(top_data, 0, sizeof(Dtype) * (*top)[0]->count()));
-  // NOLINT_NEXT_LINE(whitespace/operators)
-  PaddingForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-      count, bottom_data, top_data, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
-      PAD_);
-  CUDA_POST_KERNEL_CHECK;
-  return Dtype(0);
-}
-
-template <typename Dtype>
-__global__ void PaddingBackward(const int count, const Dtype* in, Dtype* out,
-    const int num, const int channel, const int height_in, const int width_in,
-    const int pad) {
-  CUDA_KERNEL_LOOP(index, count) {
-    int height_out = height_in + pad + pad;
-    int width_out = width_in + pad + pad;
-    int w = index % width_in;
-    index /= width_in;
-    int h = index % height_in;
-    index /= height_in;
-    int c = index % channel;
-    index /= channel;
-    out[((index * channel + c) * height_in + h) * width_in + w] =
-        in[((index * channel + c) * height_out + h + pad) *
-           width_out + pad + w];
-  }
-}
-
-template <typename Dtype>
-void PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
-    vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
-    const Dtype* top_diff = top[0]->gpu_diff();
-    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
-    const int count = (*bottom)[0]->count();
-    // NOLINT_NEXT_LINE(whitespace/operators)
-    PaddingBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-        count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
-        PAD_);
-    CUDA_POST_KERNEL_CHECK;
-  }
-}
-
-INSTANTIATE_CLASS(PaddingLayer);
-
-}  // namespace caffe
index 51ea340..34d59f9 100644 (file)
@@ -118,15 +118,14 @@ message LayerParameter {
     INNER_PRODUCT = 13;
     LRN = 14;
     MULTINOMIAL_LOGISTIC_LOSS = 15;
-    PADDING = 16;
-    POOLING = 17;
-    RELU = 18;
-    SIGMOID = 19;
-    SOFTMAX = 20;
-    SOFTMAX_LOSS = 21;
-    SPLIT = 22;
-    TANH = 23;
-    WINDOW_DATA = 24;
+    POOLING = 16;
+    RELU = 17;
+    SIGMOID = 18;
+    SOFTMAX = 19;
+    SOFTMAX_LOSS = 20;
+    SPLIT = 21;
+    TANH = 22;
+    WINDOW_DATA = 23;
   }
   optional LayerType type = 2; // the layer type from the enum above
 
@@ -151,9 +150,8 @@ message LayerParameter {
   optional InfogainLossParameter infogain_loss_param = 14;
   optional InnerProductParameter inner_product_param = 15;
   optional LRNParameter lrn_param = 16;
-  optional PaddingParameter padding_param = 17;
-  optional PoolingParameter pooling_param = 18;
-  optional WindowDataParameter window_data_param = 19;
+  optional PoolingParameter pooling_param = 17;
+  optional WindowDataParameter window_data_param = 18;
 }
 
 // Message that stores parameters used by ConcatLayer
@@ -259,11 +257,6 @@ message LRNParameter {
   optional float beta = 3 [default = 0.75]; // for local response norm
 }
 
-// Message that stores parameters used by PaddingLayer
-message PaddingParameter {
-  optional uint32 pad = 1 [default = 0]; // The padding size
-}
-
 // Message that stores parameters used by PoolingLayer
 message PoolingParameter {
   enum PoolMethod {
diff --git a/src/caffe/test/test_padding_layer.cpp b/src/caffe/test/test_padding_layer.cpp
deleted file mode 100644 (file)
index 59b012c..0000000
+++ /dev/null
@@ -1,121 +0,0 @@
-// Copyright 2014 BVLC and contributors.
-
-#include <cuda_runtime.h>
-#include <cstring>
-#include <vector>
-
-#include "gtest/gtest.h"
-#include "caffe/blob.hpp"
-#include "caffe/common.hpp"
-#include "caffe/filler.hpp"
-#include "caffe/vision_layers.hpp"
-#include "caffe/test/test_gradient_check_util.hpp"
-
-#include "caffe/test/test_caffe_main.hpp"
-
-namespace caffe {
-
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
-template <typename Dtype>
-class PaddingLayerTest : public ::testing::Test {
- protected:
-  PaddingLayerTest()
-      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
-        blob_top_(new Blob<Dtype>()) {
-    // fill the values
-    FillerParameter filler_param;
-    GaussianFiller<Dtype> filler(filler_param);
-    filler.Fill(this->blob_bottom_);
-    blob_bottom_vec_.push_back(blob_bottom_);
-    blob_top_vec_.push_back(blob_top_);
-  }
-  virtual ~PaddingLayerTest() { delete blob_bottom_; delete blob_top_; }
-  Blob<Dtype>* const blob_bottom_;
-  Blob<Dtype>* const blob_top_;
-  vector<Blob<Dtype>*> blob_bottom_vec_;
-  vector<Blob<Dtype>*> blob_top_vec_;
-};
-
-typedef ::testing::Types<float, double> Dtypes;
-TYPED_TEST_CASE(PaddingLayerTest, Dtypes);
-
-TYPED_TEST(PaddingLayerTest, TestCPU) {
-  LayerParameter layer_param;
-  PaddingParameter* padding_param = layer_param.mutable_padding_param();
-  padding_param->set_pad(1);
-  Caffe::set_mode(Caffe::CPU);
-  PaddingLayer<TypeParam> layer(layer_param);
-  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
-  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
-  EXPECT_EQ(this->blob_top_->num(), 2);
-  EXPECT_EQ(this->blob_top_->channels(), 3);
-  EXPECT_EQ(this->blob_top_->height(), 6);
-  EXPECT_EQ(this->blob_top_->width(), 7);
-  for (int n = 0; n < 2; ++n) {
-    for (int c = 0; c < 3; ++c) {
-      for (int h = 0; h < 4; ++h) {
-        for (int w = 0; w < 5; ++w) {
-          EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w),
-              this->blob_top_->data_at(n, c, h + 1, w + 1));
-        }
-      }
-    }
-  }
-}
-
-TYPED_TEST(PaddingLayerTest, TestCPUGrad) {
-  LayerParameter layer_param;
-  PaddingParameter* padding_param = layer_param.mutable_padding_param();
-  padding_param->set_pad(1);
-  Caffe::set_mode(Caffe::CPU);
-  PaddingLayer<TypeParam> layer(layer_param);
-  GradientChecker<TypeParam> checker(1e-2, 1e-3);
-  checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
-      &(this->blob_top_vec_));
-}
-
-TYPED_TEST(PaddingLayerTest, TestGPU) {
-  if (CAFFE_TEST_CUDA_PROP.major >= 2) {
-    LayerParameter layer_param;
-    PaddingParameter* padding_param = layer_param.mutable_padding_param();
-    padding_param->set_pad(1);
-    Caffe::set_mode(Caffe::GPU);
-    PaddingLayer<TypeParam> layer(layer_param);
-    layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
-    layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
-    EXPECT_EQ(this->blob_top_->num(), 2);
-    EXPECT_EQ(this->blob_top_->channels(), 3);
-    EXPECT_EQ(this->blob_top_->height(), 6);
-    EXPECT_EQ(this->blob_top_->width(), 7);
-    for (int n = 0; n < 2; ++n) {
-      for (int c = 0; c < 3; ++c) {
-        for (int h = 0; h < 4; ++h) {
-          for (int w = 0; w < 5; ++w) {
-            EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w),
-                this->blob_top_->data_at(n, c, h + 1, w + 1));
-          }
-        }
-      }
-    }
-  } else {
-    LOG(ERROR) << "Skipping test (gpu version too low).";
-  }
-}
-
-TYPED_TEST(PaddingLayerTest, TestGPUGrad) {
-  if (CAFFE_TEST_CUDA_PROP.major >= 2) {
-    LayerParameter layer_param;
-    PaddingParameter* padding_param = layer_param.mutable_padding_param();
-    padding_param->set_pad(1);
-    Caffe::set_mode(Caffe::GPU);
-    PaddingLayer<TypeParam> layer(layer_param);
-    GradientChecker<TypeParam> checker(1e-2, 1e-3);
-    checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
-        &(this->blob_top_vec_));
-  } else {
-    LOG(ERROR) << "Skipping test (gpu version too low).";
-  }
-}
-
-}  // namespace caffe