--- /dev/null
+// Copyright 2013 Rowland Depp
+
+#ifndef CAFFE_UTIL_MKL_ALTERNATE_H_
+#define CAFFE_UTIL_MKL_ALTERNATE_H_
+
+#ifdef USE_MKL
+
+#include <mkl.h>
+
+#else // If use MKL, simply include the MKL header
+
+#include <cblas.h>
+#include <math.h>
+
+// Functions that caffe uses but are not present if MKL is not linked.
+
+// A simple way to define the vsl unary functions. The operation should
+// be in the form e.g. y[i] = sqrt(a[i])
+#define DEFINE_VSL_UNARY_FUNC(name, operation) \
+ template<typename Dtype> \
+ void v##name(const int n, const Dtype* a, Dtype* y) { \
+ CHECK_GT(n, 0); CHECK(a); CHECK(y); \
+ for (int i = 0; i < n; ++i) { operation; } \
+ } \
+ inline void vs##name( \
+ const int n, const float* a, float* y) { \
+ v##name<float>(n, a, y); \
+ } \
+ inline void vd##name( \
+ const int n, const double* a, double* y) { \
+ v##name<double>(n, a, y); \
+ }
+
+DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
+DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));
+
+// A simple way to define the vsl unary functions with singular parameter b.
+// The operation should be in the form e.g. y[i] = pow(a[i], b)
+#define DEFINE_VSL_UNARY_FUNC_WITH_PARAM(name, operation) \
+ template<typename Dtype> \
+ void v##name(const int n, const Dtype* a, const Dtype b, Dtype* y) { \
+ CHECK_GT(n, 0); CHECK(a); CHECK(y); \
+ for (int i = 0; i < n; ++i) { operation; } \
+ } \
+ inline void vs##name( \
+ const int n, const float* a, const float b, float* y) { \
+ v##name<float>(n, a, b, y); \
+ } \
+ inline void vd##name( \
+ const int n, const double* a, const float b, double* y) { \
+ v##name<double>(n, a, b, y); \
+ }
+
+DEFINE_VSL_UNARY_FUNC_WITH_PARAM(Powx, y[i] = pow(a[i], b));
+
+// A simple way to define the vsl binary functions. The operation should
+// be in the form e.g. y[i] = a[i] + b[i]
+#define DEFINE_VSL_BINARY_FUNC(name, operation) \
+ template<typename Dtype> \
+ void v##name(const int n, const Dtype* a, const Dtype* b, Dtype* y) { \
+ CHECK_GT(n, 0); CHECK(a); CHECK(b); CHECK(y); \
+ for (int i = 0; i < n; ++i) { operation; } \
+ } \
+ inline void vs##name( \
+ const int n, const float* a, const float* b, float* y) { \
+ v##name<float>(n, a, b, y); \
+ } \
+ inline void vd##name( \
+ const int n, const double* a, const double* b, double* y) { \
+ v##name<double>(n, a, b, y); \
+ }
+
+DEFINE_VSL_BINARY_FUNC(Add, y[i] = a[i] + b[i]);
+DEFINE_VSL_BINARY_FUNC(Sub, y[i] = a[i] - b[i]);
+DEFINE_VSL_BINARY_FUNC(Mul, y[i] = a[i] * b[i]);
+DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i]);
+
+// In addition, MKL comes with an additional function axpby that is not present
+// in standard blas. We will simply use a two-step (inefficient, of course) way
+// to mimic that.
+inline void cblas_saxpby(const int N, const float alpha, const float* X,
+ const int incX, const float beta, float* Y,
+ const int incY) {
+ cblas_sscal(N, beta, Y, incY);
+ cblas_saxpy(N, alpha, X, incX, Y, incY);
+}
+inline void cblas_daxpby(const int N, const double alpha, const double* X,
+ const int incX, const double beta, double* Y,
+ const int incY) {
+ cblas_dscal(N, beta, Y, incY);
+ cblas_daxpy(N, alpha, X, incX, Y, incY);
+}
+
+#endif // USE_MKL
+#endif // CAFFE_UTIL_MKL_ALTERNATE_H_
#include <limits>
//#include <mkl.h>
-#include <eigen3/Eigen/Dense>
#include <boost/math/special_functions/next.hpp>
#include <boost/random.hpp>
namespace caffe {
-// Operations on aligned memory are faster than on unaligned memory.
-// But unfortunately, the pointers passed in are not always aligned.
-// Therefore, the memory-aligned Eigen::Map objects that wrap them
-// cannot be assigned to. This happens in lrn_layer and makes
-// test_lrn_layer crash with segmentation fault.
-// TODO: Use aligned Eigen::Map when the pointer to be wrapped is aligned.
-
-// Though the default map option is unaligned, making it explicit is no harm.
-//const int data_alignment = Eigen::Aligned; // how is data allocated ?
-const int data_alignment = Eigen::Unaligned;
-typedef Eigen::Array<float, 1, Eigen::Dynamic> float_array_t;
-typedef Eigen::Map<const float_array_t, data_alignment> const_map_vector_float_t;
-typedef Eigen::Map<float_array_t, data_alignment> map_vector_float_t;
-typedef Eigen::Array<double, 1, Eigen::Dynamic> double_array_t;
-typedef Eigen::Map<const double_array_t, data_alignment> const_map_vector_double_t;
-typedef Eigen::Map<double_array_t, data_alignment> map_vector_double_t;
-
template<>
void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
void caffe_axpy<double>(const int N, const double alpha, const double* X,
double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); }
-
template <>
void caffe_gpu_axpy<float>(const int N, const float alpha, const float* X,
float* Y) {
}
template <>
-void caffe_axpby<float>(const int N, const float alpha, const float* X,
- const float beta, float* Y) {
- // y := a*x + b*y
- //cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
- CHECK_GE(N, 0);
- CHECK(X);
- CHECK(Y);
- map_vector_float_t y_map(Y, N);
- // Eigen produces optimized code using lasy evaluation
- // http://eigen.tuxfamily.org/dox/TopicLazyEvaluation.html
- y_map = const_map_vector_float_t(X, N) * alpha + y_map * beta;
+void caffe_cpu_axpby<float>(const int N, const float alpha, const float* X,
+ const float beta, float* Y) {
+ cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
}
template <>
-void caffe_axpby<double>(const int N, const double alpha, const double* X,
- const double beta, double* Y) {
- // y := a*x + b*y
- //cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
- CHECK_GE(N, 0);
- CHECK(X);
- CHECK(Y);
- map_vector_double_t y_map(Y, N);
- y_map = const_map_vector_double_t(X, N) * alpha + y_map * beta;
+void caffe_cpu_axpby<double>(const int N, const double alpha, const double* X,
+ const double beta, double* Y) {
+ cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
}
template <>
void caffe_add<float>(const int n, const float* a, const float* b,
float* y) {
- //vsAdd(n, a, b, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(b);
- CHECK(y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n) +
- const_map_vector_float_t(b, n);
+ vsAdd(n, a, b, y);
}
template <>
void caffe_add<double>(const int n, const double* a, const double* b,
double* y) {
- //vdAdd(n, a, b, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(b);
- CHECK(y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n) +
- const_map_vector_double_t(b, n);
+ vdAdd(n, a, b, y);
}
template <>
void caffe_sub<float>(const int n, const float* a, const float* b,
float* y) {
- //vsSub(n, a, b, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(b);
- CHECK(y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n) -
- const_map_vector_float_t(b, n);
+ vsSub(n, a, b, y);
}
template <>
void caffe_sub<double>(const int n, const double* a, const double* b,
double* y) {
- //vdSub(n, a, b, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(b);
- CHECK(y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n) -
- const_map_vector_double_t(b, n);
+ vdSub(n, a, b, y);
}
template <>
void caffe_mul<float>(const int n, const float* a, const float* b,
float* y) {
- //vsMul(n, a, b, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(b);
- CHECK(y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n) *
- const_map_vector_float_t(b, n);
+ vsMul(n, a, b, y);
}
template <>
void caffe_mul<double>(const int n, const double* a, const double* b,
double* y) {
- //vdMul(n, a, b, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(b);
- CHECK(y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n) *
- const_map_vector_double_t(b, n);
+ vdMul(n, a, b, y);
}
template <>
void caffe_div<float>(const int n, const float* a, const float* b,
float* y) {
- //vsDiv(n, a, b, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(b);
- CHECK(y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n) /
- const_map_vector_float_t(b, n);
+ vsDiv(n, a, b, y);
}
template <>
void caffe_div<double>(const int n, const double* a, const double* b,
double* y) {
- //vdDiv(n, a, b, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(b);
- CHECK(y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n) /
- const_map_vector_double_t(b, n);
+ vdDiv(n, a, b, y);
}
template <>
void caffe_powx<float>(const int n, const float* a, const float b,
float* y) {
- //vsPowx(n, a, b, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n).pow(b);
+ vsPowx(n, a, b, y);
}
template <>
void caffe_powx<double>(const int n, const double* a, const double b,
double* y) {
- //vdPowx(n, a, b, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n).pow(b);
+ vdPowx(n, a, b, y);
}
template <>
void caffe_sqr<float>(const int n, const float* a, float* y) {
- // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-F003F826-81BF-42EC-AE51-2EF624893133.htm
- // v?Sqr Performs element by element squaring of the vector.
- //vsSqr(n, a, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(y);
- caffe_powx<float>(n, a, 2, y);
- // TODO: which is faster?
-// map_vector_float_t(y, n) = const_map_vector_float_t(a, n) *
-// const_map_vector_float_t(a, n);
+ vsSqr(n, a, y);
}
template <>
void caffe_sqr<double>(const int n, const double* a, double* y) {
- //vdSqr(n, a, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(y);
- caffe_powx<double>(n, a, 2, y);
+ vdSqr(n, a, y);
}
template <>
void caffe_exp<float>(const int n, const float* a, float* y) {
- //vsExp(n, a, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n).exp();
+ vsExp(n, a, y);
}
template <>
void caffe_exp<double>(const int n, const double* a, double* y) {
- //vdExp(n, a, y);
- CHECK_GE(n, 0);
- CHECK(a);
- CHECK(y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n).exp();
+ vdExp(n, a, y);
}
template <typename Dtype>