replace all memcpy by caffe_copy
author     Evan Shelhamer <shelhamer@imaginarynumber.net>
           Sat, 28 Jun 2014 02:51:33 +0000 (19:51 -0700)
committer  Evan Shelhamer <shelhamer@imaginarynumber.net>
           Fri, 4 Jul 2014 00:14:12 +0000 (17:14 -0700)
24 files changed:
include/caffe/blob.hpp
include/caffe/syncedmem.hpp
matlab/caffe/matcaffe.cpp
src/caffe/blob.cpp
src/caffe/layers/concat_layer.cu
src/caffe/layers/data_layer.cu
src/caffe/layers/dropout_layer.cu
src/caffe/layers/eltwise_layer.cu
src/caffe/layers/hdf5_data_layer.cu
src/caffe/layers/hdf5_output_layer.cpp
src/caffe/layers/hdf5_output_layer.cu
src/caffe/layers/image_data_layer.cu
src/caffe/layers/power_layer.cu
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
src/caffe/layers/softmax_layer.cpp
src/caffe/layers/softmax_layer.cu
src/caffe/layers/softmax_loss_layer.cpp
src/caffe/layers/window_data_layer.cu
src/caffe/solver.cpp
src/caffe/syncedmem.cpp
src/caffe/test/test_gradient_check_util.hpp
src/caffe/test/test_math_functions.cpp
src/caffe/test/test_syncedmem.cpp
src/caffe/test/test_util_blas.cpp

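Note: the unified caffe_copy is declared in caffe/util/math_functions.hpp (hence the new includes below) and dispatches on Caffe::mode(), so a single call covers host-to-host, host-to-device, and device-to-device copies. A minimal sketch of the idea, based on how caffe_copy is defined in Caffe's math_functions.cpp; the exact body at this revision may differ and is not part of this diff:

    template <typename Dtype>
    void caffe_copy(const int N, const Dtype* X, Dtype* Y) {
      if (X != Y) {
        if (Caffe::mode() == Caffe::GPU) {
          // cudaMemcpyDefault infers the transfer direction from the pointers.
          CUDA_CHECK(cudaMemcpy(Y, X, sizeof(Dtype) * N, cudaMemcpyDefault));
        } else {
          memcpy(Y, X, sizeof(Dtype) * N);
        }
      }
    }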
diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp
index c04375a..bbea86a 100644
@@ -6,6 +6,7 @@
 #include "caffe/common.hpp"
 #include "caffe/syncedmem.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/math_functions.hpp"
 
 namespace caffe {
 
diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp
index bed55c3..2b7f349 100644
@@ -6,6 +6,7 @@
 #include <cstdlib>
 
 #include "caffe/common.hpp"
+#include "caffe/util/math_functions.hpp"
 
 namespace caffe {
 
diff --git a/matlab/caffe/matcaffe.cpp b/matlab/caffe/matcaffe.cpp
index 1b2b65e..957ebea 100644
@@ -54,12 +54,12 @@ static mxArray* do_forward(const mxArray* const bottom) {
         reinterpret_cast<const float* const>(mxGetPr(elem));
     switch (Caffe::mode()) {
     case Caffe::CPU:
-      memcpy(input_blobs[i]->mutable_cpu_data(), data_ptr,
-          sizeof(float) * input_blobs[i]->count());
+      caffe_copy(input_blobs[i]->count(), data_ptr,
+          input_blobs[i]->mutable_cpu_data());
       break;
     case Caffe::GPU:
-      cudaMemcpy(input_blobs[i]->mutable_gpu_data(), data_ptr,
-          sizeof(float) * input_blobs[i]->count(), cudaMemcpyDefault);
+      caffe_copy(input_blobs[i]->count(), data_ptr,
+          input_blobs[i]->mutable_gpu_data());
       break;
     default:
       LOG(FATAL) << "Unknown Caffe mode.";
@@ -77,12 +77,12 @@ static mxArray* do_forward(const mxArray* const bottom) {
     float* data_ptr = reinterpret_cast<float*>(mxGetPr(mx_blob));
     switch (Caffe::mode()) {
     case Caffe::CPU:
-      memcpy(data_ptr, output_blobs[i]->cpu_data(),
-          sizeof(float) * output_blobs[i]->count());
+      caffe_copy(output_blobs[i]->count(), output_blobs[i]->cpu_data(),
+          data_ptr);
       break;
     case Caffe::GPU:
-      cudaMemcpy(data_ptr, output_blobs[i]->gpu_data(),
-          sizeof(float) * output_blobs[i]->count(), cudaMemcpyDefault);
+      caffe_copy(output_blobs[i]->count(), output_blobs[i]->gpu_data(),
+          data_ptr);
       break;
     default:
       LOG(FATAL) << "Unknown Caffe mode.";
@@ -104,12 +104,12 @@ static mxArray* do_backward(const mxArray* const top_diff) {
         reinterpret_cast<const float* const>(mxGetPr(elem));
     switch (Caffe::mode()) {
     case Caffe::CPU:
-      memcpy(output_blobs[i]->mutable_cpu_diff(), data_ptr,
-        sizeof(float) * output_blobs[i]->count());
+      caffe_copy(output_blobs[i]->count(), data_ptr,
+          output_blobs[i]->mutable_cpu_diff());
       break;
     case Caffe::GPU:
-      cudaMemcpy(output_blobs[i]->mutable_gpu_diff(), data_ptr,
-        sizeof(float) * output_blobs[i]->count(), cudaMemcpyDefault);
+      caffe_copy(output_blobs[i]->count(), data_ptr,
+          output_blobs[i]->mutable_gpu_diff());
       break;
     default:
       LOG(FATAL) << "Unknown Caffe mode.";
@@ -129,12 +129,10 @@ static mxArray* do_backward(const mxArray* const top_diff) {
     float* data_ptr = reinterpret_cast<float*>(mxGetPr(mx_blob));
     switch (Caffe::mode()) {
     case Caffe::CPU:
-      memcpy(data_ptr, input_blobs[i]->cpu_diff(),
-          sizeof(float) * input_blobs[i]->count());
+      caffe_copy(input_blobs[i]->count(), input_blobs[i]->cpu_diff(), data_ptr);
       break;
     case Caffe::GPU:
-      cudaMemcpy(data_ptr, input_blobs[i]->gpu_diff(),
-          sizeof(float) * input_blobs[i]->count(), cudaMemcpyDefault);
+      caffe_copy(input_blobs[i]->count(), input_blobs[i]->gpu_diff(), data_ptr);
       break;
     default:
       LOG(FATAL) << "Unknown Caffe mode.";
@@ -206,12 +204,12 @@ static mxArray* do_get_weights() {
 
         switch (Caffe::mode()) {
         case Caffe::CPU:
-          memcpy(weights_ptr, layer_blobs[j]->cpu_data(),
-              sizeof(float) * layer_blobs[j]->count());
+          caffe_copy(layer_blobs[j]->count(), layer_blobs[j]->cpu_data(),
+              weights_ptr);
           break;
         case Caffe::GPU:
-          CUDA_CHECK(cudaMemcpy(weights_ptr, layer_blobs[j]->gpu_data(),
-              sizeof(float) * layer_blobs[j]->count(), cudaMemcpyDefault));
+          caffe_copy(layer_blobs[j]->count(), layer_blobs[j]->gpu_data(),
+              weights_ptr);
           break;
         default:
           LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index eb1075a..69ff49e 100644
@@ -149,20 +149,20 @@ void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {
   switch (Caffe::mode()) {
   case Caffe::GPU:
     if (copy_diff) {
-      CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(), source.gpu_diff(),
-          sizeof(Dtype) * count_, cudaMemcpyDefault));
+      caffe_copy(count_, source.gpu_diff(),
+          reinterpret_cast<Dtype*>(diff_->mutable_gpu_data()));
     } else {
-      CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(), source.gpu_data(),
-          sizeof(Dtype) * count_, cudaMemcpyDefault));
+      caffe_copy(count_, source.gpu_data(),
+          reinterpret_cast<Dtype*>(data_->mutable_gpu_data()));
     }
     break;
   case Caffe::CPU:
     if (copy_diff) {
-      memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
-          sizeof(Dtype) * count_);
+      caffe_copy(count_, source.cpu_diff(),
+          reinterpret_cast<Dtype*>(diff_->mutable_cpu_data()));
     } else {
-      memcpy(data_->mutable_cpu_data(), source.cpu_data(),
-        sizeof(Dtype) * count_);
+      caffe_copy(count_, source.cpu_data(),
+          reinterpret_cast<Dtype*>(data_->mutable_cpu_data()));
     }
     break;
   default:
diff --git a/src/caffe/layers/concat_layer.cu b/src/caffe/layers/concat_layer.cu
index ca0cf0c..2643d74 100644
@@ -16,7 +16,7 @@ Dtype ConcatLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     int offset_num = 0;
     for (int i = 0; i < bottom.size(); ++i) {
       const Dtype* bottom_data = bottom[i]->gpu_data();
-      caffe_gpu_copy(bottom[i]->count(), bottom_data,
+      caffe_copy(bottom[i]->count(), bottom_data,
         top_data + (*top)[0]->offset(offset_num));
       offset_num += bottom[i]->num();
     }
@@ -27,7 +27,7 @@ Dtype ConcatLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       int num_elem =
         bottom[i]->channels() * bottom[i]->height() * bottom[i]->width();
       for (int n = 0; n < num_; ++n) {
-        caffe_gpu_copy(num_elem, bottom_data+bottom[i]->offset(n),
+        caffe_copy(num_elem, bottom_data+bottom[i]->offset(n),
           top_data + (*top)[0]->offset(n, offset_channel));
       }
       offset_channel += bottom[i]->channels();
@@ -49,7 +49,7 @@ void ConcatLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
       Blob<Dtype>* blob = (*bottom)[i];
       if (propagate_down[i]) {
         Dtype* bottom_diff = blob->mutable_gpu_diff();
-        caffe_gpu_copy(blob->count(), top_diff + top[0]->offset(offset_num),
+        caffe_copy(blob->count(), top_diff + top[0]->offset(offset_num),
                        bottom_diff);
       }
       offset_num += blob->num();
@@ -62,7 +62,7 @@ void ConcatLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         Dtype* bottom_diff = blob->mutable_gpu_diff();
         int num_elem = blob->channels()*blob->height()*blob->width();
         for (int n = 0; n < num_; ++n) {
-          caffe_gpu_copy(num_elem, top_diff + top[0]->offset(n, offset_channel),
+          caffe_copy(num_elem, top_diff + top[0]->offset(n, offset_channel),
                          bottom_diff + blob->offset(n));
         }
       }
diff --git a/src/caffe/layers/data_layer.cu b/src/caffe/layers/data_layer.cu
index 5a8bb0b..40316a1 100644
@@ -21,13 +21,11 @@ Dtype DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   // First, join the thread
   JoinPrefetchThread();
   // Copy the data
-  CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
-      prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
-      cudaMemcpyDefault));
+  caffe_copy(prefetch_data_->count(), prefetch_data_->cpu_data(),
+      (*top)[0]->mutable_gpu_data());
   if (output_labels_) {
-    CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
-        prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
-        cudaMemcpyDefault));
+    caffe_copy(prefetch_label_->count(), prefetch_label_->cpu_data(),
+        (*top)[1]->mutable_gpu_data());
   }
   // Start a new prefetch thread
   CreatePrefetchThread();
diff --git a/src/caffe/layers/dropout_layer.cu b/src/caffe/layers/dropout_layer.cu
index 225e091..11a4263 100644
@@ -40,7 +40,7 @@ Dtype DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
         count, bottom_data, mask, uint_thres_, scale_, top_data);
     CUDA_POST_KERNEL_CHECK;
   } else {
-    caffe_gpu_copy(count, bottom_data, top_data);
+    caffe_copy(count, bottom_data, top_data);
   }
   return Dtype(0);
 }
diff --git a/src/caffe/layers/eltwise_layer.cu b/src/caffe/layers/eltwise_layer.cu
index 3860944..99c14fe 100644
@@ -51,7 +51,7 @@ void EltwiseLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         break;
       case EltwiseParameter_EltwiseOp_SUM:
         if (coeffs_[i] == Dtype(1.)) {
-          caffe_gpu_copy(count, top_diff, bottom_diff);
+          caffe_copy(count, top_diff, bottom_diff);
         } else {
           caffe_gpu_scale(count, coeffs_[i], top_diff, bottom_diff);
         }
diff --git a/src/caffe/layers/hdf5_data_layer.cu b/src/caffe/layers/hdf5_data_layer.cu
index 232f55f..3c27f37 100644
@@ -40,16 +40,12 @@ Dtype HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       }
       current_row_ = 0;
     }
-    CUDA_CHECK(cudaMemcpy(
-            &(*top)[0]->mutable_gpu_data()[i * data_count],
-            &data_blob_.cpu_data()[current_row_ * data_count],
-            sizeof(Dtype) * data_count,
-            cudaMemcpyDefault));
-    CUDA_CHECK(cudaMemcpy(
-            &(*top)[1]->mutable_gpu_data()[i * label_data_count],
-            &label_blob_.cpu_data()[current_row_ * label_data_count],
-            sizeof(Dtype) * label_data_count,
-            cudaMemcpyDefault));
+    caffe_copy(data_count,
+        &data_blob_.cpu_data()[current_row_ * data_count],
+        &(*top)[0]->mutable_gpu_data()[i * data_count]);
+    caffe_copy(label_data_count,
+        &label_blob_.cpu_data()[current_row_ * label_data_count],
+        &(*top)[1]->mutable_gpu_data()[i * label_data_count]);
   }
   return Dtype(0.);
 }
diff --git a/src/caffe/layers/hdf5_output_layer.cpp b/src/caffe/layers/hdf5_output_layer.cpp
index 3a513b9..8307ad7 100644
@@ -54,12 +54,10 @@ Dtype HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const int label_datum_dim = bottom[1]->count() / bottom[1]->num();
 
   for (int i = 0; i < bottom[0]->num(); ++i) {
-    memcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
-           &bottom[0]->cpu_data()[i * data_datum_dim],
-           sizeof(Dtype) * data_datum_dim);
-    memcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
-           &bottom[1]->cpu_data()[i * label_datum_dim],
-           sizeof(Dtype) * label_datum_dim);
+    caffe_copy(data_datum_dim, &bottom[0]->cpu_data()[i * data_datum_dim],
+        &data_blob_.mutable_cpu_data()[i * data_datum_dim]);
+    caffe_copy(label_datum_dim, &bottom[1]->cpu_data()[i * label_datum_dim],
+        &label_blob_.mutable_cpu_data()[i * label_datum_dim]);
   }
   SaveBlobs();
   return Dtype(0.);
diff --git a/src/caffe/layers/hdf5_output_layer.cu b/src/caffe/layers/hdf5_output_layer.cu
index 19567c1..744b8fe 100644
@@ -27,12 +27,10 @@ Dtype HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const int label_datum_dim = bottom[1]->count() / bottom[1]->num();
 
   for (int i = 0; i < bottom[0]->num(); ++i) {
-    CUDA_CHECK(cudaMemcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
-           &bottom[0]->gpu_data()[i * data_datum_dim],
-           sizeof(Dtype) * data_datum_dim, cudaMemcpyDefault));
-    CUDA_CHECK(cudaMemcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
-           &bottom[1]->gpu_data()[i * label_datum_dim],
-           sizeof(Dtype) * label_datum_dim, cudaMemcpyDefault));
+    caffe_copy(data_datum_dim, &bottom[0]->gpu_data()[i * data_datum_dim],
+        &data_blob_.mutable_cpu_data()[i * data_datum_dim]);
+    caffe_copy(label_datum_dim, &bottom[1]->gpu_data()[i * label_datum_dim],
+        &label_blob_.mutable_cpu_data()[i * label_datum_dim]);
   }
   SaveBlobs();
   return Dtype(0.);
diff --git a/src/caffe/layers/image_data_layer.cu b/src/caffe/layers/image_data_layer.cu
index 3f905bc..dd5bdbc 100644
@@ -27,12 +27,10 @@ Dtype ImageDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   // First, join the thread
   JoinPrefetchThread();
   // Copy the data
-  CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
-      prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
-      cudaMemcpyDefault));
-  CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
-      prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
-      cudaMemcpyDefault));
+  caffe_copy(prefetch_data_->count(), prefetch_data_->cpu_data(),
+      (*top)[0]->mutable_gpu_data());
+  caffe_copy(prefetch_label_->count(), prefetch_label_->cpu_data(),
+      (*top)[1]->mutable_gpu_data());
   // Start a new prefetch thread
   CreatePrefetchThread();
   return Dtype(0.);
diff --git a/src/caffe/layers/power_layer.cu b/src/caffe/layers/power_layer.cu
index 6d69963..e7f9831 100644
@@ -23,7 +23,7 @@ Dtype PowerLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     return Dtype(0);
   }
   const Dtype* bottom_data = bottom[0]->gpu_data();
-  caffe_gpu_copy(count, bottom_data, top_data);
+  caffe_copy(count, bottom_data, top_data);
   if (scale_ != Dtype(1)) {
     caffe_gpu_scal(count, scale_, top_data);
   }
@@ -68,7 +68,7 @@ void PowerLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         caffe_gpu_div(count, top_data, bottom_data, bottom_diff);
         caffe_gpu_scal(count, power_, bottom_diff);
       } else {
-        caffe_gpu_copy(count, bottom_data, bottom_diff);
+        caffe_copy(count, bottom_data, bottom_diff);
         if (scale_ != Dtype(1)) {
           caffe_gpu_scal(count, scale_, bottom_diff);
         }
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
index 8f72758..0c858cd 100644
@@ -50,7 +50,7 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
     const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
     const Dtype* target = (*bottom)[1]->gpu_data();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
-    caffe_gpu_copy(count, sigmoid_output_data, bottom_diff);
+    caffe_copy(count, sigmoid_output_data, bottom_diff);
     caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff);
     // Scale down gradient
     caffe_gpu_scal(count, Dtype(1) / num, bottom_diff);
diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp
index 57847d0..5d60d9d 100644
@@ -34,7 +34,7 @@ Dtype SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   Dtype* scale_data = scale_.mutable_cpu_data();
   int num = bottom[0]->num();
   int dim = bottom[0]->count() / bottom[0]->num();
-  memcpy(top_data, bottom_data, sizeof(Dtype) * bottom[0]->count());
+  caffe_copy(bottom[0]->count(), bottom_data, top_data);
   // we need to subtract the max to avoid numerical issues, compute the exp,
   // and then normalize.
   for (int i = 0; i < num; ++i) {
@@ -68,7 +68,7 @@ void SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   Dtype* scale_data = scale_.mutable_cpu_data();
   int num = top[0]->num();
   int dim = top[0]->count() / top[0]->num();
-  memcpy(bottom_diff, top_diff, sizeof(Dtype) * top[0]->count());
+  caffe_copy(top[0]->count(), top_diff, bottom_diff);
   // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
   for (int i = 0; i < num; ++i) {
     scale_data[i] = caffe_cpu_dot<Dtype>(dim, top_diff + i * dim,
diff --git a/src/caffe/layers/softmax_layer.cu b/src/caffe/layers/softmax_layer.cu
index 5ec4412..ceeaff5 100644
@@ -50,7 +50,7 @@ Dtype SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   Dtype* scale_data = scale_.mutable_gpu_data();
   int num = bottom[0]->num();
   int dim = bottom[0]->count() / bottom[0]->num();
-  caffe_gpu_copy(bottom[0]->count(), bottom_data, top_data);
+  caffe_copy(bottom[0]->count(), bottom_data, top_data);
   // we need to subtract the max to avoid numerical issues, compute the exp,
   // and then normalize.
   // Compute max
@@ -84,7 +84,7 @@ void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
   int num = top[0]->num();
   int dim = top[0]->count() / top[0]->num();
-  caffe_gpu_copy(top[0]->count(), top_diff, bottom_diff);
+  caffe_copy(top[0]->count(), top_diff, bottom_diff);
   // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
   // cuda dot returns the result to cpu, so we temporarily change the pointer
   // mode
diff --git a/src/caffe/layers/softmax_loss_layer.cpp b/src/caffe/layers/softmax_loss_layer.cpp
index 1a3601a..37c5ebc 100644
@@ -66,7 +66,7 @@ void SoftmaxWithLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   if (propagate_down[0]) {
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     const Dtype* prob_data = prob_.cpu_data();
-    memcpy(bottom_diff, prob_data, sizeof(Dtype) * prob_.count());
+    caffe_copy(prob_.count(), prob_data, bottom_diff);
     const Dtype* label = (*bottom)[1]->cpu_data();
     int num = prob_.num();
     int dim = prob_.count() / num;
diff --git a/src/caffe/layers/window_data_layer.cu b/src/caffe/layers/window_data_layer.cu
index 5e8cdb7..ca664fc 100644
@@ -28,12 +28,10 @@ Dtype WindowDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   // First, join the thread
   JoinPrefetchThread();
   // Copy the data
-  CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
-      prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
-      cudaMemcpyDefault));
-  CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
-      prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
-      cudaMemcpyDefault));
+  caffe_copy(prefetch_data_->count(), prefetch_data_->cpu_data(),
+      (*top)[0]->mutable_gpu_data());
+  caffe_copy(prefetch_label_->count(), prefetch_label_->cpu_data(),
+      (*top)[1]->mutable_gpu_data());
   // Start a new prefetch thread
   CreatePrefetchThread();
   return Dtype(0.);
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 7696181..ca1d925 100644
@@ -310,7 +310,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
             history_[param_id]->mutable_gpu_data());
       }
       // copy
-      caffe_gpu_copy(net_params[param_id]->count(),
+      caffe_copy(net_params[param_id]->count(),
           history_[param_id]->gpu_data(),
           net_params[param_id]->mutable_gpu_diff());
     }
diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp
index ec4c528..011d359 100644
@@ -6,6 +6,7 @@
 
 #include "caffe/common.hpp"
 #include "caffe/syncedmem.hpp"
+#include "caffe/util/math_functions.hpp"
 
 namespace caffe {
 
@@ -32,7 +33,7 @@ inline void SyncedMemory::to_cpu() {
       CaffeMallocHost(&cpu_ptr_, size_);
       own_cpu_data_ = true;
     }
-    CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDefault));
+    caffe_copy(size_, gpu_ptr_, cpu_ptr_);
     head_ = SYNCED;
     break;
   case HEAD_AT_CPU:
@@ -52,7 +53,7 @@ inline void SyncedMemory::to_gpu() {
     if (gpu_ptr_ == NULL) {
       CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
     }
-    CUDA_CHECK(cudaMemcpy(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyDefault));
+    caffe_copy(size_, cpu_ptr_, gpu_ptr_);
     head_ = SYNCED;
     break;
   case HEAD_AT_GPU:
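Unlike the typed call sites above, the SyncedMemory copies here pass a byte count (size_) and untyped void* pointers, as does test_syncedmem.cpp further below, so they assume a raw overload alongside the templated caffe_copy. A hypothetical sketch of what such an overload would have to do; it is not part of this diff:

    void caffe_copy(const size_t N, const void* X, void* Y) {
      // N is a size in bytes here, not an element count; cudaMemcpyDefault
      // lets the driver infer host/device direction under unified addressing.
      CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
    }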
diff --git a/src/caffe/test/test_gradient_check_util.hpp b/src/caffe/test/test_gradient_check_util.hpp
index ff104b9..2d551f8 100644
@@ -230,7 +230,7 @@ Dtype GradientChecker<Dtype>::GetObjAndGradient(vector<Blob<Dtype>*>* top,
         loss += top_blob_data[j] * top_blob_data[j];
       }
       // set the diff: simply the data.
-      memcpy(top_blob_diff, top_blob_data, sizeof(Dtype) * top_blob->count());
+      caffe_copy(top_blob->count(), top_blob_data, top_blob_diff);
     }
     loss /= 2.;
   } else {
@@ -238,7 +238,7 @@ Dtype GradientChecker<Dtype>::GetObjAndGradient(vector<Blob<Dtype>*>* top,
     for (int i = 0; i < top->size(); ++i) {
       Blob<Dtype>* top_blob = (*top)[i];
       Dtype* top_blob_diff = top_blob->mutable_cpu_diff();
-      memset(top_blob_diff, 0, sizeof(Dtype) * top_blob->count());
+      caffe_set(top_blob->count(), Dtype(0), top_blob_diff);
     }
     loss = (*top)[top_id]->cpu_data()[top_data_id];
     (*top)[top_id]->mutable_cpu_diff()[top_data_id] = 1.;
diff --git a/src/caffe/test/test_math_functions.cpp b/src/caffe/test/test_math_functions.cpp
index d026576..941d8b9 100644
@@ -219,7 +219,7 @@ TYPED_TEST(MathFunctionsTest, TestCopyGPU) {
   const int n = this->blob_bottom_->count();
   const TypeParam* bottom_data = this->blob_bottom_->gpu_data();
   TypeParam* top_data = this->blob_top_->mutable_gpu_data();
-  caffe_gpu_copy(n, bottom_data, top_data);
+  caffe_copy(n, bottom_data, top_data);
   bottom_data = this->blob_bottom_->cpu_data();
   top_data = this->blob_top_->mutable_cpu_data();
   for (int i = 0; i < n; ++i) {
diff --git a/src/caffe/test/test_syncedmem.cpp b/src/caffe/test/test_syncedmem.cpp
index e741f0f..a7a5131 100644
@@ -7,6 +7,7 @@
 #include "gtest/gtest.h"
 #include "caffe/common.hpp"
 #include "caffe/syncedmem.hpp"
+#include "caffe/util/math_functions.hpp"
 
 #include "caffe/test/test_caffe_main.hpp"
 
@@ -57,8 +58,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
   // check if values are the same
   char* recovered_value = new char[10];
-  cudaMemcpy(reinterpret_cast<void*>(recovered_value), gpu_data, 10,
-             cudaMemcpyDefault);
+  caffe_copy(size_t(10), gpu_data, reinterpret_cast<void*>(recovered_value));
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 1);
   }
@@ -72,8 +72,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   gpu_data = mem.gpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
   // check if values are the same
-  cudaMemcpy(reinterpret_cast<void*>(recovered_value), gpu_data, 10,
-             cudaMemcpyDefault);
+  caffe_copy(size_t(10), gpu_data, reinterpret_cast<void*>(recovered_value));
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 2);
   }
diff --git a/src/caffe/test/test_util_blas.cpp b/src/caffe/test/test_util_blas.cpp
index 2e4c679..5b4c48e 100644
@@ -30,8 +30,8 @@ TYPED_TEST(GemmTest, TestGemm) {
   TypeParam A_reshape_data[6] = {1, 4, 2, 5, 3, 6};
   TypeParam B_reshape_data[12] = {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12};
   TypeParam result[8] = {38, 44, 50, 56, 83, 98, 113, 128};
-  memcpy(A.mutable_cpu_data(), data, 6 * sizeof(TypeParam));
-  memcpy(B.mutable_cpu_data(), data, 12 * sizeof(TypeParam));
+  caffe_copy(6, data, A.mutable_cpu_data());
+  caffe_copy(12, data, B.mutable_cpu_data());
 
   if (sizeof(TypeParam) == 4 || CAFFE_TEST_CUDA_PROP.major >= 2) {
     // [1, 2, 3; 4 5 6] * [1, 2, 3, 4; 5, 6, 7, 8; 9, 10, 11, 12];
@@ -48,7 +48,7 @@ TYPED_TEST(GemmTest, TestGemm) {
 
     // Test when we have a transposed A
     A.Reshape(1, 1, 3, 2);
-    memcpy(A.mutable_cpu_data(), A_reshape_data, 6 * sizeof(TypeParam));
+    caffe_copy(6, A_reshape_data, A.mutable_cpu_data());
     caffe_cpu_gemm<TypeParam>(CblasTrans, CblasNoTrans, 2, 4, 3, 1.,
         A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data());
     for (int i = 0; i < 8; ++i) {
@@ -62,7 +62,7 @@ TYPED_TEST(GemmTest, TestGemm) {
 
     // Test when we have a transposed A and a transposed B too
     B.Reshape(1, 1, 4, 3);
-    memcpy(B.mutable_cpu_data(), B_reshape_data, 12 * sizeof(TypeParam));
+    caffe_copy(12, B_reshape_data, B.mutable_cpu_data());
     caffe_cpu_gemm<TypeParam>(CblasTrans, CblasTrans, 2, 4, 3, 1.,
         A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data());
     for (int i = 0; i < 8; ++i) {
@@ -76,7 +76,7 @@ TYPED_TEST(GemmTest, TestGemm) {
 
     // Test when we have a transposed B
     A.Reshape(1, 1, 2, 3);
-    memcpy(A.mutable_cpu_data(), data, 6 * sizeof(TypeParam));
+    caffe_copy(6, data, A.mutable_cpu_data());
     caffe_cpu_gemm<TypeParam>(CblasNoTrans, CblasTrans, 2, 4, 3, 1.,
         A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data());
     for (int i = 0; i < 8; ++i) {
@@ -100,8 +100,8 @@ TYPED_TEST(GemmTest, TestGemv) {
   TypeParam data[6] = {1, 2, 3, 4, 5, 6};
   TypeParam result_2[2] = {14, 32};
   TypeParam result_3[3] = {9, 12, 15};
-  memcpy(A.mutable_cpu_data(), data, 6 * sizeof(TypeParam));
-  memcpy(x.mutable_cpu_data(), data, 3 * sizeof(TypeParam));
+  caffe_copy(6, data, A.mutable_cpu_data());
+  caffe_copy(3, data, x.mutable_cpu_data());
 
   if (sizeof(TypeParam) == 4 || CAFFE_TEST_CUDA_PROP.major >= 2) {
     caffe_cpu_gemv<TypeParam>(CblasNoTrans, 2, 3, 1., A.cpu_data(),
@@ -116,7 +116,7 @@ TYPED_TEST(GemmTest, TestGemv) {
     }
 
     // Test transpose case
-    memcpy(y.mutable_cpu_data(), data, 2 * sizeof(TypeParam));
+    caffe_copy(2, data, y.mutable_cpu_data());
     caffe_cpu_gemv<TypeParam>(CblasTrans, 2, 3, 1., A.cpu_data(),
         y.cpu_data(), 0., x.mutable_cpu_data());
     for (int i = 0; i < 3; ++i) {