Add clip layer
author Harm Berntsen <harm.berntsen@nedap.com>
Mon, 18 Jan 2016 10:41:14 +0000 (11:41 +0100)
committer Wook Song <wook16.song@samsung.com>
Thu, 23 Jan 2020 13:50:47 +0000 (22:50 +0900)
include/caffe/layers/clip_layer.hpp [new file with mode: 0644]
src/caffe/layer_factory.cpp
src/caffe/layers/clip_layer.cpp [new file with mode: 0644]
src/caffe/layers/clip_layer.cu [new file with mode: 0644]
src/caffe/proto/caffe.proto
src/caffe/test/test_neuron_layer.cpp

diff --git a/include/caffe/layers/clip_layer.hpp b/include/caffe/layers/clip_layer.hpp
new file mode 100644
index 0000000..2788193
--- /dev/null
+++ b/include/caffe/layers/clip_layer.hpp
@@ -0,0 +1,75 @@
+#ifndef CAFFE_CLIP_LAYER_HPP_
+#define CAFFE_CLIP_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/neuron_layer.hpp"
+
+namespace caffe {
+
+/**
+ * @brief Clip: @f$ y = \max(min, \min(max, x)) @f$.
+ */
+template <typename Dtype>
+class ClipLayer : public NeuronLayer<Dtype> {
+ public:
+  /**
+   * @param param provides ClipParameter clip_param,
+   *     with ClipLayer options:
+   *   - min
+   *   - max
+   */
+  explicit ClipLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+
+  virtual inline const char* type() const { return "Clip"; }
+
+ protected:
+  /**
+   * @param bottom input Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the inputs @f$ x @f$
+   * @param top output Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the computed outputs @f$
+   *        y = \max(min, \min(max, x))
+   *      @f$
+   */
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  /**
+   * @brief Computes the error gradient w.r.t. the clipped inputs.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient with
+   *      respect to the outputs
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      containing error gradients @f$ \frac{\partial E}{\partial y} @f$
+   *      with respect to computed outputs @f$ y @f$
+   * @param propagate_down see Layer::Backward.
+   * @param bottom input Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the inputs @f$ x @f$; Backward fills their diff with
+   *      gradients @f$
+   *        \frac{\partial E}{\partial x} = \left\{
+   *        \begin{array}{lr}
+   *            0 & \mathrm{if} \; x < min \vee x > max \\
+   *            \frac{\partial E}{\partial y} & \mathrm{if} \; x \ge min \wedge x \le max
+   *        \end{array} \right.
+   *      @f$
+   */
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_CLIP_LAYER_HPP_
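
Worked example of the forward/backward formulas documented above (values chosen purely for illustration), with min = -1 and max = 2:

    x = -4.0  ->  y = -1.0,  dE/dx = 0        (clipped at min)
    x =  0.5  ->  y =  0.5,  dE/dx = dE/dy    (inside the range)
    x =  3.0  ->  y =  2.0,  dE/dx = 0        (clipped at max)

The top gradient is passed through unchanged inside [min, max] and is zero wherever the clip saturates.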
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 9f9026b..d998443 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -7,6 +7,7 @@
 
 #include "caffe/layer.hpp"
 #include "caffe/layer_factory.hpp"
+#include "caffe/layers/clip_layer.hpp"
 #include "caffe/layers/conv_layer.hpp"
 #include "caffe/layers/deconv_layer.hpp"
 #include "caffe/layers/lrn_layer.hpp"
diff --git a/src/caffe/layers/clip_layer.cpp b/src/caffe/layers/clip_layer.cpp
new file mode 100644
index 0000000..7638701
--- /dev/null
+++ b/src/caffe/layers/clip_layer.cpp
@@ -0,0 +1,50 @@
+#include <algorithm>
+#include <vector>
+#include "caffe/layers/clip_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  const int count = bottom[0]->count();
+
+  Dtype min = this->layer_param_.clip_param().min();
+  Dtype max = this->layer_param_.clip_param().max();
+
+  for (int i = 0; i < count; ++i) {
+    top_data[i] = std::max(min, std::min(bottom_data[i], max));
+  }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    const Dtype* bottom_data = bottom[0]->cpu_data();
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+    const int count = bottom[0]->count();
+
+    Dtype min = this->layer_param_.clip_param().min();
+    Dtype max = this->layer_param_.clip_param().max();
+
+    for (int i = 0; i < count; ++i) {
+      bottom_diff[i] = top_diff[i] * (
+              bottom_data[i] >= min && bottom_data[i] <= max);
+    }
+  }
+}
+
+
+#ifdef CPU_ONLY
+STUB_GPU(ClipLayer);
+#endif
+
+INSTANTIATE_CLASS(ClipLayer);
+REGISTER_LAYER_CLASS(Clip);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/clip_layer.cu b/src/caffe/layers/clip_layer.cu
new file mode 100644
index 0000000..f780447
--- /dev/null
+++ b/src/caffe/layers/clip_layer.cu
@@ -0,0 +1,66 @@
+#include <vector>
+#include "caffe/layers/clip_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+__global__ void ClipForward(const int n, const float* in, float* out,
+    float p_min, float p_max) {
+  CUDA_KERNEL_LOOP(index, n) {
+    out[index] = fmaxf(p_min, fminf(in[index], p_max));
+  }
+}
+
+__global__ void ClipForward(const int n, const double* in, double* out,
+    double p_min, double p_max) {
+  CUDA_KERNEL_LOOP(index, n) {
+    out[index] = fmax(p_min, fmin(in[index], p_max));
+  }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  Dtype p_min = this->layer_param_.clip_param().min();
+  Dtype p_max = this->layer_param_.clip_param().max();
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  ClipForward<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, top_data, p_min, p_max);
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void ClipBackward(const int n, const Dtype* in_diff,
+    const Dtype* in_data, Dtype* out_diff, Dtype p_min, Dtype p_max) {
+  CUDA_KERNEL_LOOP(index, n) {
+    out_diff[index] = in_diff[index] * (
+            in_data[index] >= p_min && in_data[index] <= p_max);
+  }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    const Dtype* bottom_data = bottom[0]->gpu_data();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+    const int count = bottom[0]->count();
+    Dtype p_min = this->layer_param_.clip_param().min();
+    Dtype p_max = this->layer_param_.clip_param().max();
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    ClipBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, top_diff, bottom_data, bottom_diff, p_min, p_max);
+    CUDA_POST_KERNEL_CHECK;
+  }
+}
+
+
+INSTANTIATE_LAYER_GPU_FUNCS(ClipLayer);
+
+
+}  // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index f784aa9..5c235c6 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -322,7 +322,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 148 (last added: swish_param)
+// LayerParameter next available layer-specific ID: 149 (last added: clip_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
@@ -378,6 +378,7 @@ message LayerParameter {
   optional ArgMaxParameter argmax_param = 103;
   optional BatchNormParameter batch_norm_param = 139;
   optional BiasParameter bias_param = 141;
+  optional ClipParameter clip_param = 148;
   optional ConcatParameter concat_param = 104;
   optional ContrastiveLossParameter contrastive_loss_param = 105;
   optional ConvolutionParameter convolution_param = 106;
@@ -505,6 +506,12 @@ message ArgMaxParameter {
   optional int32 axis = 3;
 }
 
+// Message that stores parameters used by ClipLayer
+message ClipParameter {
+  required float min = 1;
+  required float max = 2;
+}
+
 message ConcatParameter {
   // The axis along which to concatenate -- may be negative to index from the
   // end (e.g., -1 for the last axis).  Other axes must have the
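
With ClipParameter and the clip_param field added above, the layer can be configured directly from a network prototxt. A minimal illustrative snippet (the layer/blob names and the bounds 0 and 6 are made up for this example; only the "Clip" type string and the clip_param fields come from this patch):

    layer {
      name: "clip1"
      type: "Clip"
      bottom: "conv1"
      top: "clip1"
      clip_param {
        min: 0
        max: 6
      }
    }

The "Clip" type string matches ClipLayer::type() and the REGISTER_LAYER_CLASS(Clip) registration in clip_layer.cpp.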
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index 83d80fc..5865e08 100644
--- a/src/caffe/test/test_neuron_layer.cpp
+++ b/src/caffe/test/test_neuron_layer.cpp
@@ -10,6 +10,7 @@
 
 #include "caffe/layers/absval_layer.hpp"
 #include "caffe/layers/bnll_layer.hpp"
+#include "caffe/layers/clip_layer.hpp"
 #include "caffe/layers/dropout_layer.hpp"
 #include "caffe/layers/elu_layer.hpp"
 #include "caffe/layers/exp_layer.hpp"
@@ -206,6 +207,38 @@ TYPED_TEST(NeuronLayerTest, TestAbsGradient) {
       this->blob_top_vec_);
 }
 
+TYPED_TEST(NeuronLayerTest, TestClip) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "clip_param { min: -1, max: 2 }", &layer_param));
+  ClipLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_GE(top_data[i], -1);
+    EXPECT_LE(top_data[i], 2);
+    EXPECT_TRUE(bottom_data[i] > -1 || top_data[i] == -1);
+    EXPECT_TRUE(bottom_data[i] < 2 || top_data[i] == 2);
+    EXPECT_TRUE(!(bottom_data[i] >= -1 && bottom_data[i] <= 2)
+            || top_data[i] == bottom_data[i]);
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestClipGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "clip_param { min: -1, max: 2 }", &layer_param));
+  ClipLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
 TYPED_TEST(NeuronLayerTest, TestReLU) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
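
TestClip checks that every output lies in [-1, 2] and equals the input wherever the input was already inside that range; TestClipGradient runs GradientChecker(1e-2, 1e-3), i.e. a finite-difference step of 1e-2 and a threshold of 1e-3, element-wise against the analytic Backward. Both cases run with the rest of the suite via make runtest, or individually by passing a --gtest_filter such as '*NeuronLayerTest*Clip*' to the test binary (the binary name and path depend on the build setup).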