add caffe_cpu_strided_dot for strided dot products
authorJonathan L Long <jonlong@cs.berkeley.edu>
Mon, 18 Aug 2014 06:15:03 +0000 (23:15 -0700)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Mon, 18 Aug 2014 07:04:29 +0000 (00:04 -0700)
This provides a more direct interface to the cblas_?dot functions.
This is useful, for example, for taking dot products across channels.

include/caffe/util/math_functions.hpp
src/caffe/util/math_functions.cpp

index 6a608d5..f4310bf 100644 (file)
@@ -95,6 +95,10 @@ template <typename Dtype>
 Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y);
 
 template <typename Dtype>
+Dtype caffe_cpu_strided_dot(const int n, const Dtype* x, const int incx,
+    const Dtype* y, const int incy);
+
+template <typename Dtype>
 int caffe_cpu_hamming_distance(const int n, const Dtype* x, const Dtype* y);
 
 // Returns the sum of the absolute values of the elements of vector x
index bac06f8..7016e63 100644 (file)
@@ -315,14 +315,27 @@ void caffe_rng_bernoulli<double>(const int n, const double p, unsigned int* r);
 template
 void caffe_rng_bernoulli<float>(const int n, const float p, unsigned int* r);
 
+template <typename Dtype>
+Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y) {
+  return caffe_cpu_strided_dot(n, x, 1, y, 1);
+}
+
+template
+float caffe_cpu_dot<float>(const int n, const float* x, const float* y);
+
+template
+double caffe_cpu_dot<double>(const int n, const double* x, const double* y);
+
 template <>
-float caffe_cpu_dot<float>(const int n, const float* x, const float* y) {
-  return cblas_sdot(n, x, 1, y, 1);
+float caffe_cpu_strided_dot<float>(const int n, const float* x, const int incx,
+    const float* y, const int incy) {
+  return cblas_sdot(n, x, incx, y, incy);
 }
 
 template <>
-double caffe_cpu_dot<double>(const int n, const double* x, const double* y) {
-  return cblas_ddot(n, x, 1, y, 1);
+double caffe_cpu_strided_dot<double>(const int n, const double* x,
+    const int incx, const double* y, const int incy) {
+  return cblas_ddot(n, x, incx, y, incy);
 }
 
 template <>