reduce caffe_copy to instantiations, split off caffe_memcpy for void*
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 28 Jun 2014 20:48:37 +0000 (13:48 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 4 Jul 2014 00:20:30 +0000 (17:20 -0700)
include/caffe/util/math_functions.hpp
src/caffe/layers/dropout_layer.cu
src/caffe/syncedmem.cpp
src/caffe/test/test_syncedmem.cpp
src/caffe/util/math_functions.cpp

index b5f4dfb..97a0571 100644 (file)
@@ -59,7 +59,7 @@ void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X,
 template <typename Dtype>
 void caffe_copy(const int N, const Dtype *X, Dtype *Y);
 
-void caffe_copy(const size_t N, const void *X, void *Y);
+void caffe_memcpy(const size_t N, const void *X, void *Y);
 
 template <typename Dtype>
 void caffe_set(const int N, const Dtype alpha, Dtype *X);
index 11a4263..c9f3ecd 100644 (file)
@@ -71,7 +71,7 @@ void DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
           count, top_diff, mask, uint_thres_, scale_, bottom_diff);
       CUDA_POST_KERNEL_CHECK;
     } else {
-      caffe_gpu_copy(top[0]->count(), top_diff, bottom_diff);
+      caffe_copy(top[0]->count(), top_diff, bottom_diff);
     }
   }
 }
index 011d359..9fe5528 100644 (file)
@@ -33,7 +33,7 @@ inline void SyncedMemory::to_cpu() {
       CaffeMallocHost(&cpu_ptr_, size_);
       own_cpu_data_ = true;
     }
-    caffe_copy(size_, gpu_ptr_, cpu_ptr_);
+    caffe_memcpy(size_, gpu_ptr_, cpu_ptr_);
     head_ = SYNCED;
     break;
   case HEAD_AT_CPU:
@@ -53,7 +53,7 @@ inline void SyncedMemory::to_gpu() {
     if (gpu_ptr_ == NULL) {
       CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
     }
-    caffe_copy(size_, cpu_ptr_, gpu_ptr_);
+    caffe_memcpy(size_, cpu_ptr_, gpu_ptr_);
     head_ = SYNCED;
     break;
   case HEAD_AT_GPU:
index a7a5131..3aaeafc 100644 (file)
@@ -58,7 +58,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
   // check if values are the same
   char* recovered_value = new char[10];
-  caffe_copy(size_t(10), gpu_data, reinterpret_cast<void*>(recovered_value));
+  caffe_memcpy(10, gpu_data, recovered_value);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 1);
   }
@@ -72,7 +72,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   gpu_data = mem.gpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
   // check if values are the same
-  caffe_copy(size_t(10), gpu_data, reinterpret_cast<void*>(recovered_value));
+  caffe_memcpy(10, gpu_data, recovered_value);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 2);
   }
index b1b62ed..918bb3c 100644 (file)
@@ -149,34 +149,20 @@ void caffe_add_scalar(const int N, const double alpha, double* Y) {
   }
 }
 
-template <>
-void caffe_copy<int>(const int N, const int* X, int* Y) {
-  if (X != Y) {
-    CUDA_CHECK(cudaMemcpy(Y, X, sizeof(int) * N, cudaMemcpyDefault));
-  }
-}
-
-template <>
-void caffe_copy<unsigned int>(const int N, const unsigned int* X,
-    unsigned int* Y) {
-  if (X != Y) {
-  CUDA_CHECK(cudaMemcpy(Y, X, sizeof(unsigned int) * N, cudaMemcpyDefault));
-  }
-}
-
-template <>
-void caffe_copy<float>(const int N, const float* X, float* Y) {
+template <typename Dtype>
+void caffe_copy(const int N, const Dtype* X, Dtype* Y) {
   if (X != Y) {
-    CUDA_CHECK(cudaMemcpy(Y, X, sizeof(float) * N, cudaMemcpyDefault));
+    CUDA_CHECK(cudaMemcpy(Y, X, sizeof(Dtype) * N, cudaMemcpyDefault));
   }
 }
 
-template <>
-void caffe_copy<double>(const int N, const double* X, double* Y) {
-  CUDA_CHECK(cudaMemcpy(Y, X, sizeof(double) * N, cudaMemcpyDefault));
-}
+template void caffe_copy<int>(const int N, const int* X, int* Y);
+template void caffe_copy<unsigned int>(const int N, const unsigned int* X,
+    unsigned int* Y);
+template void caffe_copy<float>(const int N, const float* X, float* Y);
+template void caffe_copy<double>(const int N, const double* X, double* Y);
 
-void caffe_copy(const size_t N, const void* X, void* Y) {
+void caffe_memcpy(const size_t N, const void* X, void* Y) {
   if (X != Y) {
     CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
   }