Added Swish layer (#6002)
authorMikhail Antonenka <anmikh@users.noreply.github.com>
Sat, 17 Mar 2018 15:26:40 +0000 (18:26 +0300)
committerWook Song <wook16.song@samsung.com>
Thu, 23 Jan 2020 13:50:38 +0000 (22:50 +0900)
* added swish layer (cpu)

* swish layer: added tests

* swish layer: optimized backpropogation

* swish layer: added cuda implementation

* swish layer: added beta parameter

* swish layer: incorporated sigmoid layer

* swish layer: fix comment of last added parameter

* swish layer: added REGISTER_LAYER_CLASS

include/caffe/layers/swish_layer.hpp [new file with mode: 0644]
src/caffe/layers/swish_layer.cpp [new file with mode: 0644]
src/caffe/layers/swish_layer.cu [new file with mode: 0644]
src/caffe/proto/caffe.proto
src/caffe/test/test_neuron_layer.cpp

diff --git a/include/caffe/layers/swish_layer.hpp b/include/caffe/layers/swish_layer.hpp
new file mode 100644 (file)
index 0000000..d538ff6
--- /dev/null
@@ -0,0 +1,96 @@
+#ifndef CAFFE_SWISH_LAYER_HPP_
+#define CAFFE_SWISH_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/neuron_layer.hpp"
+#include "caffe/layers/sigmoid_layer.hpp"
+
+namespace caffe {
+
+/**
+ * @brief Swish non-linearity @f$ y = x \sigma (\beta x) @f$.
+ *        A novel activation function that tends to work better than ReLU [1].
+ *
+ * [1] Prajit Ramachandran, Barret Zoph, Quoc V. Le. "Searching for
+ *     Activation Functions". arXiv preprint arXiv:1710.05941v2 (2017).
+ */
+template <typename Dtype>
+class SwishLayer : public NeuronLayer<Dtype> {
+ public:
+  /**
+   * @param param provides SwishParameter swish_param,
+   *     with SwishLayer options:
+   *   - beta (\b optional, default 1).
+   *     the value @f$ \beta @f$ in the @f$ y = x \sigma (\beta x) @f$.
+   */
+  explicit SwishLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param),
+        sigmoid_layer_(new SigmoidLayer<Dtype>(param)),
+        sigmoid_input_(new Blob<Dtype>()),
+        sigmoid_output_(new Blob<Dtype>()) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual inline const char* type() const { return "Swish"; }
+
+ protected:
+  /**
+   * @param bottom input Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the inputs @f$ x @f$
+   * @param top output Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the computed outputs @f$
+   *        y = x \sigma (\beta x)
+   *      @f$.
+   */
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  /**
+   * @brief Computes the error gradient w.r.t. the sigmoid inputs.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient with
+   *      respect to the outputs
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      containing error gradients @f$ \frac{\partial E}{\partial y} @f$
+   *      with respect to computed outputs @f$ y @f$
+   * @param propagate_down see Layer::Backward.
+   * @param bottom input Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the inputs @f$ x @f$; Backward fills their diff with
+   *      gradients @f$
+   *        \frac{\partial E}{\partial x}
+   *            = \frac{\partial E}{\partial y}(\beta y +
+   *              \sigma (\beta x)(1 - \beta y))
+   *      @f$ if propagate_down[0]
+   */
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  /// The internal SigmoidLayer
+  shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
+  /// sigmoid_input_ stores the input of the SigmoidLayer.
+  shared_ptr<Blob<Dtype> > sigmoid_input_;
+  /// sigmoid_output_ stores the output of the SigmoidLayer.
+  shared_ptr<Blob<Dtype> > sigmoid_output_;
+  /// bottom vector holder to call the underlying SigmoidLayer::Forward
+  vector<Blob<Dtype>*> sigmoid_bottom_vec_;
+  /// top vector holder to call the underlying SigmoidLayer::Forward
+  vector<Blob<Dtype>*> sigmoid_top_vec_;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_SWISH_LAYER_HPP_
diff --git a/src/caffe/layers/swish_layer.cpp b/src/caffe/layers/swish_layer.cpp
new file mode 100644 (file)
index 0000000..2893567
--- /dev/null
@@ -0,0 +1,68 @@
+#include <cmath>
+#include <vector>
+
+#include "caffe/layers/swish_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SwishLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  NeuronLayer<Dtype>::LayerSetUp(bottom, top);
+  sigmoid_bottom_vec_.clear();
+  sigmoid_bottom_vec_.push_back(sigmoid_input_.get());
+  sigmoid_top_vec_.clear();
+  sigmoid_top_vec_.push_back(sigmoid_output_.get());
+  sigmoid_layer_->SetUp(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  NeuronLayer<Dtype>::Reshape(bottom, top);
+  sigmoid_input_->ReshapeLike(*bottom[0]);
+  sigmoid_layer_->Reshape(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* sigmoid_input_data = sigmoid_input_->mutable_cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  const int count = bottom[0]->count();
+  Dtype beta = this->layer_param_.swish_param().beta();
+  caffe_copy(count, bottom_data, sigmoid_input_data);
+  caffe_scal(count, beta, sigmoid_input_data);
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+  caffe_mul(count, bottom_data, sigmoid_output_->cpu_data(), top_data);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    const Dtype* top_data = top[0]->cpu_data();
+    const Dtype* top_diff = top[0]->cpu_diff();
+    const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
+    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+    const int count = bottom[0]->count();
+    Dtype beta = this->layer_param_.swish_param().beta();
+    for (int i = 0; i < count; ++i) {
+      const Dtype swish_x = top_data[i];
+      bottom_diff[i] = top_diff[i] * (beta * swish_x + sigmoid_output_data[i]
+          * (1. - beta * swish_x));
+    }
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(SwishLayer);
+#endif
+
+INSTANTIATE_CLASS(SwishLayer);
+REGISTER_LAYER_CLASS(Swish);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/swish_layer.cu b/src/caffe/layers/swish_layer.cu
new file mode 100644 (file)
index 0000000..c4fef53
--- /dev/null
@@ -0,0 +1,54 @@
+#include <cmath>
+#include <vector>
+
+#include "caffe/layers/swish_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* sigmoid_input_data = sigmoid_input_->mutable_gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  Dtype beta = this->layer_param_.swish_param().beta();
+  caffe_copy(count, bottom_data, sigmoid_input_data);
+  caffe_gpu_scal(count, beta, sigmoid_input_data);
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+  caffe_gpu_mul(count, bottom_data, sigmoid_output_->gpu_data(), top_data);
+}
+
+template <typename Dtype>
+__global__ void SwishBackward(const int n, const Dtype* in_diff,
+    const Dtype* out_data, const Dtype* sigmoid_output_data, Dtype* out_diff,
+    const Dtype beta) {
+  CUDA_KERNEL_LOOP(index, n) {
+    const Dtype swish_x = out_data[index];
+    out_diff[index] = in_diff[index] * (beta * swish_x
+        + sigmoid_output_data[index] * (1 - beta * swish_x));
+  }
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    const Dtype* top_data = top[0]->gpu_data();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
+    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+    const int count = bottom[0]->count();
+    Dtype beta = this->layer_param_.swish_param().beta();
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    SwishBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, top_diff, top_data, sigmoid_output_data, bottom_diff, beta);
+    CUDA_POST_KERNEL_CHECK;
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(SwishLayer);
+
+}  // namespace caffe
index 22764ab..b9bb3f4 100644 (file)
@@ -322,7 +322,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 147 (last added: recurrent_param)
+// LayerParameter next available layer-specific ID: 148 (last added: swish_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
@@ -415,6 +415,7 @@ message LayerParameter {
   optional SoftmaxParameter softmax_param = 125;
   optional SPPParameter spp_param = 132;
   optional SliceParameter slice_param = 126;
+  optional SwishParameter swish_param = 147;
   optional TanHParameter tanh_param = 127;
   optional ThresholdParameter threshold_param = 128;
   optional TileParameter tile_param = 138;
@@ -1156,6 +1157,15 @@ message SoftmaxParameter {
   optional int32 axis = 2 [default = 1];
 }
 
+// Message that stores parameters used by SwishLayer
+message SwishParameter {
+  // Beta parameter for the Swish activation function
+  // Described in:
+  // Prajit Ramachandran, Barret Zoph, Quoc V. Le. (2017). Searching for
+  // Activation Functions. https://arxiv.org/abs/1710.05941v2
+  optional float beta = 1 [default = 1];
+}
+
 message TanHParameter {
   enum Engine {
     DEFAULT = 0;
index 180871a..83d80fc 100644 (file)
@@ -19,6 +19,7 @@
 #include "caffe/layers/prelu_layer.hpp"
 #include "caffe/layers/relu_layer.hpp"
 #include "caffe/layers/sigmoid_layer.hpp"
+#include "caffe/layers/swish_layer.hpp"
 #include "caffe/layers/tanh_layer.hpp"
 #include "caffe/layers/threshold_layer.hpp"
 
@@ -344,6 +345,84 @@ TYPED_TEST(NeuronLayerTest, TestSigmoidGradient) {
       this->blob_top_vec_);
 }
 
+TYPED_TEST(NeuronLayerTest, TestSwish) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  SwishLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / (1. + exp(-bottom_data[i])));
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishWithBeta) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "swish_param { beta: 1.5 }", &layer_param));
+  SwishLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / (1. + exp(-1.5 *
+        bottom_data[i])));
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishAsLinear) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "swish_param { beta: 0.0 }", &layer_param));
+  SwishLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / 2.0);
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  SwishLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishWithBetaGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "swish_param { beta: 1.5 }", &layer_param));
+  SwishLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishAsLinearGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "swish_param { beta: 0.0 }", &layer_param));
+  SwishLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
 TYPED_TEST(NeuronLayerTest, TestTanH) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;