From 1991826a79e7f6a4bc99a8ecc85878cacac92a53 Mon Sep 17 00:00:00 2001 From: Alireza Shafaei Date: Sat, 9 Aug 2014 22:44:12 -0700 Subject: [PATCH] Added absolute value layer, useful for implementation of siamese networks! This commit also replaces the default caffe_fabs with MKL/non-MKL implementation of Abs. --- include/caffe/neuron_layers.hpp | 30 ++++++++++++++++++++++ include/caffe/util/math_functions.hpp | 6 +++++ include/caffe/util/mkl_alternate.hpp | 1 + src/caffe/layer_factory.cpp | 2 ++ src/caffe/layers/absval_layer.cpp | 46 ++++++++++++++++++++++++++++++++++ src/caffe/layers/absval_layer.cu | 34 +++++++++++++++++++++++++ src/caffe/proto/caffe.proto | 3 ++- src/caffe/test/test_math_functions.cpp | 4 +-- src/caffe/test/test_neuron_layer.cpp | 23 +++++++++++++++++ src/caffe/util/math_functions.cpp | 11 +++++++- src/caffe/util/math_functions.cu | 23 ++++++++++++++++- 11 files changed, 178 insertions(+), 5 deletions(-) create mode 100644 src/caffe/layers/absval_layer.cpp create mode 100644 src/caffe/layers/absval_layer.cu diff --git a/include/caffe/neuron_layers.hpp b/include/caffe/neuron_layers.hpp index 20f7f6d..d17120f 100644 --- a/include/caffe/neuron_layers.hpp +++ b/include/caffe/neuron_layers.hpp @@ -38,6 +38,36 @@ class NeuronLayer : public Layer { virtual inline int ExactNumTopBlobs() const { return 1; } }; +/* AbsVal Layer + y = |x| + + y' = 1 if x > 0 + = -1 if x < 0 +*/ +template +class AbsValLayer : public NeuronLayer { + public: + explicit AbsValLayer(const LayerParameter& param) + : NeuronLayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + vector*>* top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_ABSVAL; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + protected: + virtual void Forward_cpu(const vector*>& bottom, + vector*>* top); + virtual void Forward_gpu(const vector*>& bottom, + vector*>* top); + 
virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, vector*>* bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, vector*>* bottom); +}; + /* BNLLLayer y = x + log(1 + exp(-x)) if x > 0 diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index 90a1a86..6a608d5 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -89,6 +89,9 @@ template void caffe_exp(const int n, const Dtype* a, Dtype* y); template +void caffe_abs(const int n, const Dtype* a, Dtype* y); + +template Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y); template @@ -197,6 +200,9 @@ template void caffe_gpu_div(const int N, const Dtype* a, const Dtype* b, Dtype* y); template +void caffe_gpu_abs(const int n, const Dtype* a, Dtype* y); + +template void caffe_gpu_powx(const int n, const Dtype* a, const Dtype b, Dtype* y); // caffe_gpu_rng_uniform with two arguments generates integers in the range diff --git a/include/caffe/util/mkl_alternate.hpp b/include/caffe/util/mkl_alternate.hpp index d72bcd2..32fdbf7 100644 --- a/include/caffe/util/mkl_alternate.hpp +++ b/include/caffe/util/mkl_alternate.hpp @@ -33,6 +33,7 @@ extern "C" { DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]); DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i])); +DEFINE_VSL_UNARY_FUNC(Abs, y[i] = fabs(a[i])); // A simple way to define the vsl unary functions with singular parameter b. // The operation should be in the form e.g. 
y[i] = pow(a[i], b)
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 2170c19..d18d246 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -19,6 +19,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
   switch (type) {
   case LayerParameter_LayerType_ACCURACY:
     return new AccuracyLayer<Dtype>(param);
+  case LayerParameter_LayerType_ABSVAL:
+    return new AbsValLayer<Dtype>(param);
   case LayerParameter_LayerType_ARGMAX:
     return new ArgMaxLayer<Dtype>(param);
   case LayerParameter_LayerType_BNLL:
diff --git a/src/caffe/layers/absval_layer.cpp b/src/caffe/layers/absval_layer.cpp
new file mode 100644
index 0000000..ce9d05c
--- /dev/null
+++ b/src/caffe/layers/absval_layer.cpp
@@ -0,0 +1,46 @@
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/neuron_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AbsValLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  NeuronLayer<Dtype>::LayerSetUp(bottom, top);
+  CHECK_NE((*top)[0], bottom[0]) << this->type_name() << " Layer does not "
+    "allow in-place computation.";
+}
+
+template <typename Dtype>
+void AbsValLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  const int count = (*top)[0]->count();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  caffe_abs(count, bottom[0]->cpu_data(), top_data);
+}
+
+template <typename Dtype>
+void AbsValLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const int count = top[0]->count();
+  const Dtype* top_diff = top[0]->cpu_diff();
+  if (propagate_down[0]) {
+    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    // d|x|/dx = sign(x); |x| / x would be 0/0 = NaN at x == 0.
+    caffe_cpu_sign(count, bottom_data, bottom_diff);
+    caffe_mul(count, bottom_diff, top_diff, bottom_diff);
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(AbsValLayer);
+#endif
+
+INSTANTIATE_CLASS(AbsValLayer);
+
+
+}  // namespace caffe
diff --git a/src/caffe/layers/absval_layer.cu
b/src/caffe/layers/absval_layer.cu
new file mode 100644
index 0000000..46778aa
--- /dev/null
+++ b/src/caffe/layers/absval_layer.cu
@@ -0,0 +1,34 @@
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/neuron_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AbsValLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  const int count = (*top)[0]->count();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  caffe_gpu_abs(count, bottom[0]->gpu_data(), top_data);
+}
+
+template <typename Dtype>
+void AbsValLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const int count = top[0]->count();
+  const Dtype* top_diff = top[0]->gpu_diff();
+  if (propagate_down[0]) {
+    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    // d|x|/dx = sign(x); |x| / x would be 0/0 = NaN at x == 0.
+    caffe_gpu_sign(count, bottom_data, bottom_diff);
+    caffe_gpu_mul(count, bottom_diff, top_diff, bottom_diff);
+  }
+}
+
+INSTANTIATE_CLASS(AbsValLayer);
+
+
+}  // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 971a291..9cc0de9 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -205,12 +205,13 @@ message LayerParameter {
   // line above the enum. Update the next available ID when you add a new
   // LayerType.
   //
-  // LayerType next available ID: 35 (last added: MVN)
+  // LayerType next available ID: 36 (last added: ABSVAL)
   enum LayerType {
     // "NONE" layer type is 0th enum element so that we don't cause confusion
     // by defaulting to an existent LayerType (instead, should usually error if
     // the type is unspecified).
NONE = 0; + ABSVAL = 35; ACCURACY = 1; ARGMAX = 30; BNLL = 2; diff --git a/src/caffe/test/test_math_functions.cpp b/src/caffe/test/test_math_functions.cpp index d10e702..667f744 100644 --- a/src/caffe/test/test_math_functions.cpp +++ b/src/caffe/test/test_math_functions.cpp @@ -113,7 +113,7 @@ TYPED_TEST(MathFunctionsTest, TestSgnbitCPU) { TYPED_TEST(MathFunctionsTest, TestFabsCPU) { int n = this->blob_bottom_->count(); const TypeParam* x = this->blob_bottom_->cpu_data(); - caffe_cpu_fabs(n, x, this->blob_bottom_->mutable_cpu_diff()); + caffe_abs(n, x, this->blob_bottom_->mutable_cpu_diff()); const TypeParam* abs_val = this->blob_bottom_->cpu_diff(); for (int i = 0; i < n; ++i) { EXPECT_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]); @@ -194,7 +194,7 @@ TYPED_TEST(MathFunctionsTest, TestSgnbitGPU) { TYPED_TEST(MathFunctionsTest, TestFabsGPU) { int n = this->blob_bottom_->count(); - caffe_gpu_fabs(n, this->blob_bottom_->gpu_data(), + caffe_gpu_abs(n, this->blob_bottom_->gpu_data(), this->blob_bottom_->mutable_gpu_diff()); const TypeParam* abs_val = this->blob_bottom_->cpu_diff(); const TypeParam* x = this->blob_bottom_->cpu_data(); diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp index 649f8f6..29dcec5 100644 --- a/src/caffe/test/test_neuron_layer.cpp +++ b/src/caffe/test/test_neuron_layer.cpp @@ -70,6 +70,29 @@ class NeuronLayerTest : public MultiDeviceTest { TYPED_TEST_CASE(NeuronLayerTest, TestDtypesAndDevices); +TYPED_TEST(NeuronLayerTest, TestAbsVal) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + AbsValLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); + layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); + const Dtype* bottom_data = this->blob_bottom_->cpu_data(); + const Dtype* top_data = this->blob_top_->cpu_data(); + const int count = this->blob_bottom_->count(); + for (int i = 0; i < count; ++i) { + EXPECT_EQ(top_data[i], fabs(bottom_data[i])); + 
} +} + +TYPED_TEST(NeuronLayerTest, TestAbsGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + AbsValLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3, 1701, 0., 0.01); + checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_), + &(this->blob_top_vec_)); +} + TYPED_TEST(NeuronLayerTest, TestReLU) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index e10f019..bac06f8 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -206,6 +206,16 @@ void caffe_exp(const int n, const double* a, double* y) { vdExp(n, a, y); } +template <> +void caffe_abs(const int n, const float* a, float* y) { + vsAbs(n, a, y); +} + +template <> +void caffe_abs(const int n, const double* a, double* y) { + vdAbs(n, a, y); +} + unsigned int caffe_rng_rand() { return (*caffe_rng())(); } @@ -349,7 +359,6 @@ double caffe_cpu_asum(const int n, const double* x) { INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sign); INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sgnbit); -INSTANTIATE_CAFFE_CPU_UNARY_FUNC(fabs); template <> void caffe_cpu_scale(const int n, const float alpha, const float *x, diff --git a/src/caffe/util/math_functions.cu b/src/caffe/util/math_functions.cu index eacbb47..4ae4bba 100644 --- a/src/caffe/util/math_functions.cu +++ b/src/caffe/util/math_functions.cu @@ -282,6 +282,28 @@ void caffe_gpu_div(const int N, const double* a, } template +__global__ void abs_kernel(const int n, const Dtype* a, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = abs(a[index]); + } +} + +template <> +void caffe_gpu_abs(const int N, const float* a, float* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + abs_kernel<<>>( + N, a, y); +} + +template <> +void caffe_gpu_abs(const int N, const double* a, double* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + abs_kernel<<>>( + N, a, y); +} + + +template __global__ void powx_kernel(const 
int n, const Dtype* a, const Dtype alpha, Dtype* y) { CUDA_KERNEL_LOOP(index, n) { @@ -308,7 +330,6 @@ void caffe_gpu_powx(const int N, const double* a, DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sign, y[index] = (Dtype(0) < x[index]) - (x[index] < Dtype(0))); DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sgnbit, y[index] = signbit(x[index])); -DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(fabs, y[index] = fabs(x[index])); __global__ void popc_kernel(const int n, const float* a, const float* b, uint8_t* y) { -- 2.7.4