Added absolute value layer, useful for implementation of siamese networks!
authorAlireza Shafaei <alireza@shafaei.net>
Sun, 10 Aug 2014 05:44:12 +0000 (22:44 -0700)
committerAlireza Shafaei <alireza@shafaei.net>
Thu, 14 Aug 2014 17:06:34 +0000 (10:06 -0700)
This commit also replaces the default caffe_fabs with MKL/non-MKL implementation of Abs.

include/caffe/neuron_layers.hpp
include/caffe/util/math_functions.hpp
include/caffe/util/mkl_alternate.hpp
src/caffe/layer_factory.cpp
src/caffe/layers/absval_layer.cpp [new file with mode: 0644]
src/caffe/layers/absval_layer.cu [new file with mode: 0644]
src/caffe/proto/caffe.proto
src/caffe/test/test_math_functions.cpp
src/caffe/test/test_neuron_layer.cpp
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.cu

index 20f7f6d..d17120f 100644 (file)
@@ -38,6 +38,36 @@ class NeuronLayer : public Layer<Dtype> {
   virtual inline int ExactNumTopBlobs() const { return 1; }
 };
 
+/* AbsVal Layer
+  y = |x|
+
+  y' = 1    if x > 0
+     = -1   if x < 0
+*/
+template <typename Dtype>
+class AbsValLayer : public NeuronLayer<Dtype> {
+ public:
+  explicit AbsValLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  virtual inline LayerParameter_LayerType type() const {
+    return LayerParameter_LayerType_ABSVAL;
+  }
+  virtual inline int ExactNumBottomBlobs() const { return 1; }
+  virtual inline int ExactNumTopBlobs() const { return 1; }
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
 /* BNLLLayer
 
   y = x + log(1 + exp(-x))  if x > 0
index 90a1a86..6a608d5 100644 (file)
@@ -89,6 +89,9 @@ template <typename Dtype>
 void caffe_exp(const int n, const Dtype* a, Dtype* y);
 
 template <typename Dtype>
+void caffe_abs(const int n, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
 Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y);
 
 template <typename Dtype>
@@ -197,6 +200,9 @@ template <typename Dtype>
 void caffe_gpu_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
 
 template <typename Dtype>
+void caffe_gpu_abs(const int n, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
 void caffe_gpu_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
 
 // caffe_gpu_rng_uniform with two arguments generates integers in the range
index d72bcd2..32fdbf7 100644 (file)
@@ -33,6 +33,7 @@ extern "C" {
 
 DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
 DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));
+DEFINE_VSL_UNARY_FUNC(Abs, y[i] = fabs(a[i]));
 
 // A simple way to define the vsl unary functions with singular parameter b.
 // The operation should be in the form e.g. y[i] = pow(a[i], b)
index 2170c19..d18d246 100644 (file)
@@ -19,6 +19,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
   switch (type) {
   case LayerParameter_LayerType_ACCURACY:
     return new AccuracyLayer<Dtype>(param);
+  case LayerParameter_LayerType_ABSVAL:
+    return new AbsValLayer<Dtype>(param);
   case LayerParameter_LayerType_ARGMAX:
     return new ArgMaxLayer<Dtype>(param);
   case LayerParameter_LayerType_BNLL:
diff --git a/src/caffe/layers/absval_layer.cpp b/src/caffe/layers/absval_layer.cpp
new file mode 100644 (file)
index 0000000..ce9d05c
--- /dev/null
@@ -0,0 +1,46 @@
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/neuron_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AbsValLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  NeuronLayer<Dtype>::LayerSetUp(bottom, top);
+  CHECK_NE((*top)[0], bottom[0]) << this->type_name() << " Layer does not "
+    "allow in-place computation.";
+}
+
+template <typename Dtype>
+void AbsValLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  const int count = (*top)[0]->count();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  caffe_abs(count, bottom[0]->cpu_data(), top_data);
+}
+
+template <typename Dtype>
+void AbsValLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const int count = top[0]->count();
+  const Dtype* top_data = top[0]->cpu_data();
+  const Dtype* top_diff = top[0]->cpu_diff();
+  if (propagate_down[0]) {
+    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    caffe_div(count, top_data, bottom_data, bottom_diff);
+    caffe_mul(count, bottom_diff, top_diff, bottom_diff);
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(AbsValLayer);
+#endif
+
+INSTANTIATE_CLASS(AbsValLayer);
+
+
+}  // namespace caffe
diff --git a/src/caffe/layers/absval_layer.cu b/src/caffe/layers/absval_layer.cu
new file mode 100644 (file)
index 0000000..46778aa
--- /dev/null
@@ -0,0 +1,34 @@
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AbsValLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  const int count = (*top)[0]->count();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  caffe_gpu_abs(count, bottom[0]->gpu_data(), top_data);
+}
+
+template <typename Dtype>
+void AbsValLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const int count = top[0]->count();
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* top_diff = top[0]->gpu_diff();
+  if (propagate_down[0]) {
+    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    caffe_gpu_div(count, top_data, bottom_data, bottom_diff);
+    caffe_gpu_mul(count, bottom_diff, top_diff, bottom_diff);
+  }
+}
+
+INSTANTIATE_CLASS(AbsValLayer);
+
+
+}  // namespace caffe
index 971a291..9cc0de9 100644 (file)
@@ -205,12 +205,13 @@ message LayerParameter {
   // line above the enum. Update the next available ID when you add a new
   // LayerType.
   //
-  // LayerType next available ID: 35 (last added: MVN)
+  // LayerType next available ID: 36 (last added: ABSVAL)
   enum LayerType {
     // "NONE" layer type is 0th enum element so that we don't cause confusion
     // by defaulting to an existent LayerType (instead, should usually error if
     // the type is unspecified).
     NONE = 0;
+    ABSVAL = 35;
     ACCURACY = 1;
     ARGMAX = 30;
     BNLL = 2;
index d10e702..667f744 100644 (file)
@@ -113,7 +113,7 @@ TYPED_TEST(MathFunctionsTest, TestSgnbitCPU) {
 TYPED_TEST(MathFunctionsTest, TestFabsCPU) {
   int n = this->blob_bottom_->count();
   const TypeParam* x = this->blob_bottom_->cpu_data();
-  caffe_cpu_fabs<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
+  caffe_abs<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
   const TypeParam* abs_val = this->blob_bottom_->cpu_diff();
   for (int i = 0; i < n; ++i) {
     EXPECT_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]);
@@ -194,7 +194,7 @@ TYPED_TEST(MathFunctionsTest, TestSgnbitGPU) {
 
 TYPED_TEST(MathFunctionsTest, TestFabsGPU) {
   int n = this->blob_bottom_->count();
-  caffe_gpu_fabs<TypeParam>(n, this->blob_bottom_->gpu_data(),
+  caffe_gpu_abs<TypeParam>(n, this->blob_bottom_->gpu_data(),
                             this->blob_bottom_->mutable_gpu_diff());
   const TypeParam* abs_val = this->blob_bottom_->cpu_diff();
   const TypeParam* x = this->blob_bottom_->cpu_data();
index 649f8f6..29dcec5 100644 (file)
@@ -70,6 +70,29 @@ class NeuronLayerTest : public MultiDeviceTest<TypeParam> {
 
 TYPED_TEST_CASE(NeuronLayerTest, TestDtypesAndDevices);
 
+TYPED_TEST(NeuronLayerTest, TestAbsVal) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  AbsValLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data    = this->blob_top_->cpu_data();
+  const int count = this->blob_bottom_->count();
+  for (int i = 0; i < count; ++i) {
+    EXPECT_EQ(top_data[i], fabs(bottom_data[i]));
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestAbsGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  AbsValLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
 TYPED_TEST(NeuronLayerTest, TestReLU) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
index e10f019..bac06f8 100644 (file)
@@ -206,6 +206,16 @@ void caffe_exp<double>(const int n, const double* a, double* y) {
   vdExp(n, a, y);
 }
 
+template <>
+void caffe_abs<float>(const int n, const float* a, float* y) {
+    vsAbs(n, a, y);
+}
+
+template <>
+void caffe_abs<double>(const int n, const double* a, double* y) {
+    vdAbs(n, a, y);
+}
+
 unsigned int caffe_rng_rand() {
   return (*caffe_rng())();
 }
@@ -349,7 +359,6 @@ double caffe_cpu_asum<double>(const int n, const double* x) {
 
 INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sign);
 INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sgnbit);
-INSTANTIATE_CAFFE_CPU_UNARY_FUNC(fabs);
 
 template <>
 void caffe_cpu_scale<float>(const int n, const float alpha, const float *x,
index eacbb47..4ae4bba 100644 (file)
@@ -282,6 +282,28 @@ void caffe_gpu_div<double>(const int N, const double* a,
 }
 
 template <typename Dtype>
+__global__ void abs_kernel(const int n, const Dtype* a, Dtype* y) {
+  CUDA_KERNEL_LOOP(index, n) {
+    y[index] = abs(a[index]);
+  }
+}
+
+template <>
+void caffe_gpu_abs<float>(const int N, const float* a, float* y) {
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  abs_kernel<float><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+      N, a, y);
+}
+
+template <>
+void caffe_gpu_abs<double>(const int N, const double* a, double* y) {
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  abs_kernel<double><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+      N, a, y);
+}
+
+
+template <typename Dtype>
 __global__ void powx_kernel(const int n, const Dtype* a,
     const Dtype alpha, Dtype* y) {
   CUDA_KERNEL_LOOP(index, n) {
@@ -308,7 +330,6 @@ void caffe_gpu_powx<double>(const int N, const double* a,
 DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sign, y[index] = (Dtype(0) < x[index])
                                       - (x[index] < Dtype(0)));
 DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sgnbit, y[index] = signbit(x[index]));
-DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(fabs, y[index] = fabs(x[index]));
 
 __global__ void popc_kernel(const int n, const float* a,
     const float* b, uint8_t* y) {