From: Evan Shelhamer Date: Sat, 28 Jun 2014 01:36:48 +0000 (-0700) Subject: switch to unified virtual addressing CUDA memcpy X-Git-Tag: submit/tizen/20180823.020014~653^2~86^2~6 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9f74b6b129ab6ae161d569ce73d346553e04ec49;p=platform%2Fupstream%2Fcaffeonacl.git switch to unified virtual addressing CUDA memcpy Host / device copies are distinguished by the virtual address of the pointers instead of explicit memcpy modes. --- diff --git a/matlab/caffe/matcaffe.cpp b/matlab/caffe/matcaffe.cpp index 21f51e8..1b2b65e 100644 --- a/matlab/caffe/matcaffe.cpp +++ b/matlab/caffe/matcaffe.cpp @@ -59,7 +59,7 @@ static mxArray* do_forward(const mxArray* const bottom) { break; case Caffe::GPU: cudaMemcpy(input_blobs[i]->mutable_gpu_data(), data_ptr, - sizeof(float) * input_blobs[i]->count(), cudaMemcpyHostToDevice); + sizeof(float) * input_blobs[i]->count(), cudaMemcpyDefault); break; default: LOG(FATAL) << "Unknown Caffe mode."; @@ -82,7 +82,7 @@ static mxArray* do_forward(const mxArray* const bottom) { break; case Caffe::GPU: cudaMemcpy(data_ptr, output_blobs[i]->gpu_data(), - sizeof(float) * output_blobs[i]->count(), cudaMemcpyDeviceToHost); + sizeof(float) * output_blobs[i]->count(), cudaMemcpyDefault); break; default: LOG(FATAL) << "Unknown Caffe mode."; @@ -109,7 +109,7 @@ static mxArray* do_backward(const mxArray* const top_diff) { break; case Caffe::GPU: cudaMemcpy(output_blobs[i]->mutable_gpu_diff(), data_ptr, - sizeof(float) * output_blobs[i]->count(), cudaMemcpyHostToDevice); + sizeof(float) * output_blobs[i]->count(), cudaMemcpyDefault); break; default: LOG(FATAL) << "Unknown Caffe mode."; @@ -134,7 +134,7 @@ static mxArray* do_backward(const mxArray* const top_diff) { break; case Caffe::GPU: cudaMemcpy(data_ptr, input_blobs[i]->gpu_diff(), - sizeof(float) * input_blobs[i]->count(), cudaMemcpyDeviceToHost); + sizeof(float) * input_blobs[i]->count(), cudaMemcpyDefault); break; default: LOG(FATAL) << "Unknown Caffe mode."; @@ -211,7 +211,7 @@ static mxArray* do_get_weights() { break; case Caffe::GPU: CUDA_CHECK(cudaMemcpy(weights_ptr, layer_blobs[j]->gpu_data(), - sizeof(float) * layer_blobs[j]->count(), cudaMemcpyDeviceToHost)); + sizeof(float) * layer_blobs[j]->count(), cudaMemcpyDefault)); break; default: LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index e603712..eb1075a 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -150,10 +150,10 @@ void Blob::CopyFrom(const Blob& source, bool copy_diff, bool reshape) { case Caffe::GPU: if (copy_diff) { CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(), source.gpu_diff(), - sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice)); + sizeof(Dtype) * count_, cudaMemcpyDefault)); } else { CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(), source.gpu_data(), - sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice)); + sizeof(Dtype) * count_, cudaMemcpyDefault)); } break; case Caffe::CPU: diff --git a/src/caffe/layers/data_layer.cu b/src/caffe/layers/data_layer.cu index 2ff9a29..5a8bb0b 100644 --- a/src/caffe/layers/data_layer.cu +++ b/src/caffe/layers/data_layer.cu @@ -23,11 +23,11 @@ Dtype DataLayer::Forward_gpu(const vector*>& bottom, // Copy the data CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(), prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(), - cudaMemcpyHostToDevice)); + cudaMemcpyDefault)); if (output_labels_) { CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(), prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(), - cudaMemcpyHostToDevice)); + cudaMemcpyDefault)); } // Start a new prefetch thread CreatePrefetchThread(); diff --git a/src/caffe/layers/hdf5_data_layer.cu b/src/caffe/layers/hdf5_data_layer.cu index b2b09ef..232f55f 100644 --- a/src/caffe/layers/hdf5_data_layer.cu +++ b/src/caffe/layers/hdf5_data_layer.cu @@ -44,12 +44,12 @@ Dtype HDF5DataLayer::Forward_gpu(const vector*>& bottom, &(*top)[0]->mutable_gpu_data()[i * data_count], &data_blob_.cpu_data()[current_row_ * data_count], sizeof(Dtype) * data_count, - cudaMemcpyHostToDevice)); + cudaMemcpyDefault)); CUDA_CHECK(cudaMemcpy( &(*top)[1]->mutable_gpu_data()[i * label_data_count], &label_blob_.cpu_data()[current_row_ * label_data_count], sizeof(Dtype) * label_data_count, - cudaMemcpyHostToDevice)); + cudaMemcpyDefault)); } return Dtype(0.); } diff --git a/src/caffe/layers/hdf5_output_layer.cu b/src/caffe/layers/hdf5_output_layer.cu index 59505ee..19567c1 100644 --- a/src/caffe/layers/hdf5_output_layer.cu +++ b/src/caffe/layers/hdf5_output_layer.cu @@ -29,10 +29,10 @@ Dtype HDF5OutputLayer::Forward_gpu(const vector*>& bottom, for (int i = 0; i < bottom[0]->num(); ++i) { CUDA_CHECK(cudaMemcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim], &bottom[0]->gpu_data()[i * data_datum_dim], - sizeof(Dtype) * data_datum_dim, cudaMemcpyDeviceToHost)); + sizeof(Dtype) * data_datum_dim, cudaMemcpyDefault)); CUDA_CHECK(cudaMemcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim], &bottom[1]->gpu_data()[i * label_datum_dim], - sizeof(Dtype) * label_datum_dim, cudaMemcpyDeviceToHost)); + sizeof(Dtype) * label_datum_dim, cudaMemcpyDefault)); } SaveBlobs(); return Dtype(0.); diff --git a/src/caffe/layers/image_data_layer.cu b/src/caffe/layers/image_data_layer.cu index 9804729..3f905bc 100644 --- a/src/caffe/layers/image_data_layer.cu +++ b/src/caffe/layers/image_data_layer.cu @@ -29,10 +29,10 @@ Dtype ImageDataLayer::Forward_gpu(const vector*>& bottom, // Copy the data CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(), prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(), - cudaMemcpyHostToDevice)); + cudaMemcpyDefault)); CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(), prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(), - cudaMemcpyHostToDevice)); + cudaMemcpyDefault)); // Start a new prefetch thread CreatePrefetchThread(); return Dtype(0.); diff --git a/src/caffe/layers/window_data_layer.cu b/src/caffe/layers/window_data_layer.cu index bc49fef..5e8cdb7 100644 --- a/src/caffe/layers/window_data_layer.cu +++ b/src/caffe/layers/window_data_layer.cu @@ -30,10 +30,10 @@ Dtype WindowDataLayer::Forward_gpu(const vector*>& bottom, // Copy the data CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(), prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(), - cudaMemcpyHostToDevice)); + cudaMemcpyDefault)); CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(), prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(), - cudaMemcpyHostToDevice)); + cudaMemcpyDefault)); // Start a new prefetch thread CreatePrefetchThread(); return Dtype(0.); diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp index fec37d6..ec4c528 100644 --- a/src/caffe/syncedmem.cpp +++ b/src/caffe/syncedmem.cpp @@ -32,7 +32,7 @@ inline void SyncedMemory::to_cpu() { CaffeMallocHost(&cpu_ptr_, size_); own_cpu_data_ = true; } - CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDefault)); head_ = SYNCED; break; case HEAD_AT_CPU: @@ -52,7 +52,7 @@ inline void SyncedMemory::to_gpu() { if (gpu_ptr_ == NULL) { CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); } - CUDA_CHECK(cudaMemcpy(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyDefault)); head_ = SYNCED; break; case HEAD_AT_GPU: diff --git a/src/caffe/test/test_syncedmem.cpp b/src/caffe/test/test_syncedmem.cpp index 7bbbbab..e741f0f 100644 --- a/src/caffe/test/test_syncedmem.cpp +++ b/src/caffe/test/test_syncedmem.cpp @@ -58,7 +58,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) { // check if values are the same char* recovered_value = new char[10]; cudaMemcpy(reinterpret_cast(recovered_value), gpu_data, 10, - cudaMemcpyDeviceToHost); + cudaMemcpyDefault); for (int i = 0; i < mem.size(); ++i) { EXPECT_EQ((reinterpret_cast(recovered_value))[i], 1); } @@ -73,7 +73,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) { EXPECT_EQ(mem.head(), SyncedMemory::SYNCED); // check if values are the same cudaMemcpy(reinterpret_cast(recovered_value), gpu_data, 10, - cudaMemcpyDeviceToHost); + cudaMemcpyDefault); for (int i = 0; i < mem.size(); ++i) { EXPECT_EQ((reinterpret_cast(recovered_value))[i], 2); }