do all caffe_copy() as UVA mem copy, and drop caffe_gpu_copy()
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 28 Jun 2014 04:25:36 +0000 (21:25 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 4 Jul 2014 00:14:12 +0000 (17:14 -0700)
Do all memory copies by `cudaMemcpy` in UVA mode so that the same
`caffe_copy()` interface works for all transfers.

`cudaMemcpy()` is used in lieu of BLAS copies because they do not
understand UVA.

Drop the now unnecessary `caffe_gpu_copy()` since location of the
pointers is now irrelevant to the interface.

include/caffe/util/math_functions.hpp
src/caffe/util/math_functions.cpp

index 9951997..b5f4dfb 100644 (file)
@@ -59,6 +59,8 @@ void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X,
 template <typename Dtype>
 void caffe_copy(const int N, const Dtype *X, Dtype *Y);
 
+void caffe_copy(const size_t N, const void *X, void *Y);
+
 template <typename Dtype>
 void caffe_set(const int N, const Dtype alpha, Dtype *X);
 
@@ -66,9 +68,6 @@ template <typename Dtype>
 void caffe_gpu_set(const int N, const Dtype alpha, Dtype *X);
 
 template <typename Dtype>
-void caffe_gpu_copy(const int N, const Dtype *X, Dtype *Y);
-
-template <typename Dtype>
 void caffe_add_scalar(const int N, const Dtype alpha, Dtype *X);
 
 template <typename Dtype>
index 90df512..b1b62ed 100644 (file)
@@ -150,30 +150,35 @@ void caffe_add_scalar(const int N, const double alpha, double* Y) {
 }
 
 template <>
-void caffe_copy<float>(const int N, const float* X, float* Y) {
+void caffe_copy<int>(const int N, const int* X, int* Y) {
   if (X != Y) {
-    cblas_scopy(N, X, 1, Y, 1);
+    CUDA_CHECK(cudaMemcpy(Y, X, sizeof(int) * N, cudaMemcpyDefault));
   }
 }
 
 template <>
-void caffe_copy<double>(const int N, const double* X, double* Y) {
+void caffe_copy<unsigned int>(const int N, const unsigned int* X,
+    unsigned int* Y) {
   if (X != Y) {
-    cblas_dcopy(N, X, 1, Y, 1);
+  CUDA_CHECK(cudaMemcpy(Y, X, sizeof(unsigned int) * N, cudaMemcpyDefault));
   }
 }
 
 template <>
-void caffe_gpu_copy<float>(const int N, const float* X, float* Y) {
+void caffe_copy<float>(const int N, const float* X, float* Y) {
   if (X != Y) {
-    CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), N, X, 1, Y, 1));
+    CUDA_CHECK(cudaMemcpy(Y, X, sizeof(float) * N, cudaMemcpyDefault));
   }
 }
 
 template <>
-void caffe_gpu_copy<double>(const int N, const double* X, double* Y) {
+void caffe_copy<double>(const int N, const double* X, double* Y) {
+  CUDA_CHECK(cudaMemcpy(Y, X, sizeof(double) * N, cudaMemcpyDefault));
+}
+
+void caffe_copy(const size_t N, const void* X, void* Y) {
   if (X != Y) {
-    CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), N, X, 1, Y, 1));
+    CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
   }
 }