Add signbit math func, simplify GPU defs & instantiations with a macro
authorKai Li <kaili_kloud@163.com>
Wed, 26 Feb 2014 03:23:20 +0000 (11:23 +0800)
committerKai Li <kaili_kloud@163.com>
Sun, 23 Mar 2014 13:25:58 +0000 (21:25 +0800)
include/caffe/util/math_functions.hpp
src/caffe/test/test_math_functions.cpp
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.cu

index 5d4a8e9..268cb2b 100644 (file)
@@ -5,6 +5,7 @@
 #define CAFFE_UTIL_MATH_FUNCTIONS_H_
 
 #include <cmath> // for std::fabs
+#include <math.h> // for signbit
 #include <cublas_v2.h>
 
 #include "caffe/util/mkl_alternate.hpp"
@@ -147,11 +148,38 @@ inline char caffe_sign(Dtype val) {
   template <> \
   void caffe_cpu_##name<double>(const int n, const double* x, double* y)
 
+
+#define DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(name, operation) \
+template<typename Dtype> \
+__global__ void name##_kernel(const int n, const Dtype* x, Dtype* y) { \
+  int index = threadIdx.x + blockIdx.x * blockDim.x; \
+  if (index < n) { \
+    operation; \
+  } \
+} \
+template <> \
+void caffe_gpu_##name<float>(const int n, const float* x, float* y) { \
+  name##_kernel<float><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>( \
+      n, x, y); \
+} \
+template <> \
+void caffe_gpu_##name<double>(const int n, const double* x, double* y) { \
+  name##_kernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>( \
+      n, x, y); \
+}
+
+// output is 1 for the positives, 0 for zero, and -1 for the negatives
 DEFINE_CAFFE_CPU_UNARY_FUNC(sign, y[i] = caffe_sign<Dtype>(x[i]));
 
 template<typename Dtype>
 void caffe_gpu_sign(const int n, const Dtype* x, Dtype* y);
 
+// returns a nonzero value is the input has its sign bit set.
+DEFINE_CAFFE_CPU_UNARY_FUNC(signbit, y[i] = std::signbit(x[i]));
+
+template<typename Dtype>
+void caffe_gpu_signbit(const int n, const Dtype* x, Dtype* y);
+
 DEFINE_CAFFE_CPU_UNARY_FUNC(fabs, y[i] = std::fabs(x[i]));
 
 template <typename Dtype>
index 00f28ba..d314d73 100644 (file)
@@ -119,6 +119,27 @@ TYPED_TEST(MathFunctionsTest, TestSignGPU){
   }
 }
 
+TYPED_TEST(MathFunctionsTest, TestSignbitCPU){
+  int n = this->blob_bottom_->count();
+  const TypeParam* x = this->blob_bottom_->cpu_data();
+  caffe_cpu_signbit<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
+  const TypeParam* signbits = this->blob_bottom_->cpu_diff();
+  for (int i = 0; i < n; ++i) {
+    CHECK_EQ(signbits[i], x[i] < 0 ? 1 : 0);
+  }
+}
+
+TYPED_TEST(MathFunctionsTest, TestSignbitGPU){
+  int n = this->blob_bottom_->count();
+  caffe_gpu_signbit<TypeParam>(n, this->blob_bottom_->gpu_data(),
+                            this->blob_bottom_->mutable_gpu_diff());
+  const TypeParam* signbits = this->blob_bottom_->cpu_diff();
+  const TypeParam* x = this->blob_bottom_->cpu_data();
+  for (int i = 0; i < n; ++i) {
+    CHECK_EQ(signbits[i], x[i] < 0 ? 1 : 0);
+  }
+}
+
 TYPED_TEST(MathFunctionsTest, TestFabsCPU){
   int n = this->blob_bottom_->count();
   const TypeParam* x = this->blob_bottom_->cpu_data();
index ef347a1..ad83a99 100644 (file)
@@ -411,6 +411,7 @@ void caffe_gpu_asum<double>(const int n, const double* x, double* y) {
 }
 
 INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sign);
+INSTANTIATE_CAFFE_CPU_UNARY_FUNC(signbit);
 INSTANTIATE_CAFFE_CPU_UNARY_FUNC(fabs);
 
 template <>
index 72cbb00..e3eaacc 100644 (file)
@@ -4,7 +4,7 @@
 #include <cmath>
 #include <cstdlib>
 #include <cstring>
-#include <math_functions.h> // CUDA's, not caffe's, for fabs
+#include <math_functions.h> // CUDA's, not caffe's, for fabs, signbit
 
 #include "caffe/common.hpp"
 #include "caffe/util/math_functions.hpp"
@@ -35,44 +35,8 @@ void caffe_gpu_mul<double>(const int N, const double* a,
       N, a, b, y);
 }
 
-template<typename Dtype>
-__global__ void sign_kernel(const int n, const Dtype* x, Dtype* y) {
-  int index = threadIdx.x + blockIdx.x * blockDim.x;
-  if (index < n) {
-    y[index] = (Dtype(0) < x[index]) - (x[index] < Dtype(0));
-  }
-}
-
-template <>
-void caffe_gpu_sign<float>(const int n, const float* x, float* y) {
-  sign_kernel<float><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
-      n, x, y);
-}
-
-template <>
-void caffe_gpu_sign<double>(const int n, const double* x, double* y) {
-  sign_kernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
-      n, x, y);
-}
-
-template<typename Dtype>
-__global__ void fabs_kernel(const int n, const Dtype* x, Dtype* y) {
-  int index = threadIdx.x + blockIdx.x * blockDim.x;
-  if (index < n) {
-    y[index] = fabs(x[index]);
-  }
-}
-
-template <>
-void caffe_gpu_fabs<float>(const int n, const float* x, float* y) {
-  fabs_kernel<float><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
-      n, x, y);
-}
-
-template <>
-void caffe_gpu_fabs<double>(const int n, const double* x, double* y) {
-  fabs_kernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
-      n, x, y);
-}
+DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sign, y[index] = (Dtype(0) < x[index]) - (x[index] < Dtype(0)));
+DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(signbit, y[index] = signbit(x[index]));
+DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(fabs, y[index] = fabs(x[index]));
 
 }  // namespace caffe