Replace cudaMemcpy with caffe_gpu_memcpy in SyncedMemory per @longjon
authorKai Li <kaili_kloud@163.com>
Wed, 9 Jul 2014 23:50:31 +0000 (07:50 +0800)
committerKai Li <kaili_kloud@163.com>
Thu, 10 Jul 2014 00:03:22 +0000 (08:03 +0800)
include/caffe/util/math_functions.hpp
src/caffe/syncedmem.cpp
src/caffe/test/test_syncedmem.cpp
src/caffe/util/math_functions.cpp

index 97a0571..2df0fc9 100644 (file)
@@ -59,7 +59,7 @@ void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X,
 template <typename Dtype>
 void caffe_copy(const int N, const Dtype *X, Dtype *Y);
 
-void caffe_memcpy(const size_t N, const void *X, void *Y);
+void caffe_gpu_memcpy(const size_t N, const void *X, void *Y);
 
 template <typename Dtype>
 void caffe_set(const int N, const Dtype alpha, Dtype *X);
index 3f9a3be..77dfe7a 100644 (file)
@@ -33,7 +33,7 @@ inline void SyncedMemory::to_cpu() {
       CaffeMallocHost(&cpu_ptr_, size_);
       own_cpu_data_ = true;
     }
-    CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDefault));
+    caffe_gpu_memcpy(size_, gpu_ptr_, cpu_ptr_);
     head_ = SYNCED;
     break;
   case HEAD_AT_CPU:
@@ -53,7 +53,7 @@ inline void SyncedMemory::to_gpu() {
     if (gpu_ptr_ == NULL) {
       CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
     }
-    CUDA_CHECK(cudaMemcpy(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyDefault));
+    caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_);
     head_ = SYNCED;
     break;
   case HEAD_AT_GPU:
index 3aaeafc..3a75708 100644 (file)
@@ -58,7 +58,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
   // check if values are the same
   char* recovered_value = new char[10];
-  caffe_memcpy(10, gpu_data, recovered_value);
+  caffe_gpu_memcpy(10, gpu_data, recovered_value);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 1);
   }
@@ -72,7 +72,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   gpu_data = mem.gpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
   // check if values are the same
-  caffe_memcpy(10, gpu_data, recovered_value);
+  caffe_gpu_memcpy(10, gpu_data, recovered_value);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 2);
   }
index 9311a39..b989ca2 100644 (file)
@@ -166,13 +166,9 @@ template void caffe_copy<unsigned int>(const int N, const unsigned int* X,
 template void caffe_copy<float>(const int N, const float* X, float* Y);
 template void caffe_copy<double>(const int N, const double* X, double* Y);
 
-void caffe_memcpy(const size_t N, const void* X, void* Y) {
+void caffe_gpu_memcpy(const size_t N, const void* X, void* Y) {
   if (X != Y) {
-    if (Caffe::mode() == Caffe::GPU) {
-      CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
-    } else {
-      memcpy(Y, X, N);
-    }
+    CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
   }
 }