switch to unified virtual addressing CUDA memcpy
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 28 Jun 2014 01:36:48 +0000 (18:36 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 4 Jul 2014 00:14:11 +0000 (17:14 -0700)
Host / device copies are distinguished by the virtual address of the
pointers instead of explicit memcpy modes.

matlab/caffe/matcaffe.cpp
src/caffe/blob.cpp
src/caffe/layers/data_layer.cu
src/caffe/layers/hdf5_data_layer.cu
src/caffe/layers/hdf5_output_layer.cu
src/caffe/layers/image_data_layer.cu
src/caffe/layers/window_data_layer.cu
src/caffe/syncedmem.cpp
src/caffe/test/test_syncedmem.cpp

index 21f51e8..1b2b65e 100644 (file)
@@ -59,7 +59,7 @@ static mxArray* do_forward(const mxArray* const bottom) {
       break;
     case Caffe::GPU:
       cudaMemcpy(input_blobs[i]->mutable_gpu_data(), data_ptr,
-          sizeof(float) * input_blobs[i]->count(), cudaMemcpyHostToDevice);
+          sizeof(float) * input_blobs[i]->count(), cudaMemcpyDefault);
       break;
     default:
       LOG(FATAL) << "Unknown Caffe mode.";
@@ -82,7 +82,7 @@ static mxArray* do_forward(const mxArray* const bottom) {
       break;
     case Caffe::GPU:
       cudaMemcpy(data_ptr, output_blobs[i]->gpu_data(),
-          sizeof(float) * output_blobs[i]->count(), cudaMemcpyDeviceToHost);
+          sizeof(float) * output_blobs[i]->count(), cudaMemcpyDefault);
       break;
     default:
       LOG(FATAL) << "Unknown Caffe mode.";
@@ -109,7 +109,7 @@ static mxArray* do_backward(const mxArray* const top_diff) {
       break;
     case Caffe::GPU:
       cudaMemcpy(output_blobs[i]->mutable_gpu_diff(), data_ptr,
-        sizeof(float) * output_blobs[i]->count(), cudaMemcpyHostToDevice);
+        sizeof(float) * output_blobs[i]->count(), cudaMemcpyDefault);
       break;
     default:
       LOG(FATAL) << "Unknown Caffe mode.";
@@ -134,7 +134,7 @@ static mxArray* do_backward(const mxArray* const top_diff) {
       break;
     case Caffe::GPU:
       cudaMemcpy(data_ptr, input_blobs[i]->gpu_diff(),
-          sizeof(float) * input_blobs[i]->count(), cudaMemcpyDeviceToHost);
+          sizeof(float) * input_blobs[i]->count(), cudaMemcpyDefault);
       break;
     default:
       LOG(FATAL) << "Unknown Caffe mode.";
@@ -211,7 +211,7 @@ static mxArray* do_get_weights() {
           break;
         case Caffe::GPU:
           CUDA_CHECK(cudaMemcpy(weights_ptr, layer_blobs[j]->gpu_data(),
-              sizeof(float) * layer_blobs[j]->count(), cudaMemcpyDeviceToHost));
+              sizeof(float) * layer_blobs[j]->count(), cudaMemcpyDefault));
           break;
         default:
           LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
index e603712..eb1075a 100644 (file)
@@ -150,10 +150,10 @@ void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {
   case Caffe::GPU:
     if (copy_diff) {
       CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(), source.gpu_diff(),
-          sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+          sizeof(Dtype) * count_, cudaMemcpyDefault));
     } else {
       CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(), source.gpu_data(),
-          sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+          sizeof(Dtype) * count_, cudaMemcpyDefault));
     }
     break;
   case Caffe::CPU:
index 2ff9a29..5a8bb0b 100644 (file)
@@ -23,11 +23,11 @@ Dtype DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   // Copy the data
   CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
       prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
-      cudaMemcpyHostToDevice));
+      cudaMemcpyDefault));
   if (output_labels_) {
     CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
         prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
-        cudaMemcpyHostToDevice));
+        cudaMemcpyDefault));
   }
   // Start a new prefetch thread
   CreatePrefetchThread();
index b2b09ef..232f55f 100644 (file)
@@ -44,12 +44,12 @@ Dtype HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
             &(*top)[0]->mutable_gpu_data()[i * data_count],
             &data_blob_.cpu_data()[current_row_ * data_count],
             sizeof(Dtype) * data_count,
-            cudaMemcpyHostToDevice));
+            cudaMemcpyDefault));
     CUDA_CHECK(cudaMemcpy(
             &(*top)[1]->mutable_gpu_data()[i * label_data_count],
             &label_blob_.cpu_data()[current_row_ * label_data_count],
             sizeof(Dtype) * label_data_count,
-            cudaMemcpyHostToDevice));
+            cudaMemcpyDefault));
   }
   return Dtype(0.);
 }
index 59505ee..19567c1 100644 (file)
@@ -29,10 +29,10 @@ Dtype HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   for (int i = 0; i < bottom[0]->num(); ++i) {
     CUDA_CHECK(cudaMemcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
            &bottom[0]->gpu_data()[i * data_datum_dim],
-           sizeof(Dtype) * data_datum_dim, cudaMemcpyDeviceToHost));
+           sizeof(Dtype) * data_datum_dim, cudaMemcpyDefault));
     CUDA_CHECK(cudaMemcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
            &bottom[1]->gpu_data()[i * label_datum_dim],
-           sizeof(Dtype) * label_datum_dim, cudaMemcpyDeviceToHost));
+           sizeof(Dtype) * label_datum_dim, cudaMemcpyDefault));
   }
   SaveBlobs();
   return Dtype(0.);
index 9804729..3f905bc 100644 (file)
@@ -29,10 +29,10 @@ Dtype ImageDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   // Copy the data
   CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
       prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
-      cudaMemcpyHostToDevice));
+      cudaMemcpyDefault));
   CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
       prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
-      cudaMemcpyHostToDevice));
+      cudaMemcpyDefault));
   // Start a new prefetch thread
   CreatePrefetchThread();
   return Dtype(0.);
index bc49fef..5e8cdb7 100644 (file)
@@ -30,10 +30,10 @@ Dtype WindowDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   // Copy the data
   CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
       prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
-      cudaMemcpyHostToDevice));
+      cudaMemcpyDefault));
   CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
       prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
-      cudaMemcpyHostToDevice));
+      cudaMemcpyDefault));
   // Start a new prefetch thread
   CreatePrefetchThread();
   return Dtype(0.);
index fec37d6..ec4c528 100644 (file)
@@ -32,7 +32,7 @@ inline void SyncedMemory::to_cpu() {
       CaffeMallocHost(&cpu_ptr_, size_);
       own_cpu_data_ = true;
     }
-    CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDeviceToHost));
+    CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDefault));
     head_ = SYNCED;
     break;
   case HEAD_AT_CPU:
@@ -52,7 +52,7 @@ inline void SyncedMemory::to_gpu() {
     if (gpu_ptr_ == NULL) {
       CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
     }
-    CUDA_CHECK(cudaMemcpy(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaMemcpy(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyDefault));
     head_ = SYNCED;
     break;
   case HEAD_AT_GPU:
index 7bbbbab..e741f0f 100644 (file)
@@ -58,7 +58,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   // check if values are the same
   char* recovered_value = new char[10];
   cudaMemcpy(reinterpret_cast<void*>(recovered_value), gpu_data, 10,
-             cudaMemcpyDeviceToHost);
+             cudaMemcpyDefault);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 1);
   }
@@ -73,7 +73,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
   // check if values are the same
   cudaMemcpy(reinterpret_cast<void*>(recovered_value), gpu_data, 10,
-             cudaMemcpyDeviceToHost);
+             cudaMemcpyDefault);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 2);
   }