Fix caffe/alt_fn lint errors.
authorJeff Donahue <jeff.donahue@gmail.com>
Wed, 6 Aug 2014 00:33:19 +0000 (17:33 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Tue, 12 Aug 2014 19:43:41 +0000 (12:43 -0700)
include/caffe/util/math_functions.hpp
src/caffe/layers/hdf5_data_layer.cpp
src/caffe/syncedmem.cpp
src/caffe/test/test_syncedmem.cpp
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.cu

index 6004e93..90a1a86 100644 (file)
@@ -40,7 +40,7 @@ template <typename Dtype>
 void caffe_set(const int N, const Dtype alpha, Dtype *X);
 
 inline void caffe_memset(const size_t N, const int alpha, void* X) {
-  memset(X, alpha, N);
+  memset(X, alpha, N);  // NOLINT(caffe/alt_fn)
 }
 
 template <typename Dtype>
index 5bad5f6..938d843 100644 (file)
@@ -103,12 +103,11 @@ Dtype HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       }
       current_row_ = 0;
     }
-    memcpy(&(*top)[0]->mutable_cpu_data()[i * data_count],
-           &data_blob_.cpu_data()[current_row_ * data_count],
-           sizeof(Dtype) * data_count);
-    memcpy(&(*top)[1]->mutable_cpu_data()[i * label_data_count],
-            &label_blob_.cpu_data()[current_row_ * label_data_count],
-            sizeof(Dtype) * label_data_count);
+    caffe_copy(data_count, &data_blob_.cpu_data()[current_row_ * data_count],
+               &(*top)[0]->mutable_cpu_data()[i * data_count]);
+    caffe_copy(label_data_count,
+               &label_blob_.cpu_data()[current_row_ * label_data_count],
+               &(*top)[1]->mutable_cpu_data()[i * label_data_count]);
   }
   return Dtype(0.);
 }
index 7d25183..7617ccf 100644 (file)
@@ -22,7 +22,7 @@ inline void SyncedMemory::to_cpu() {
   switch (head_) {
   case UNINITIALIZED:
     CaffeMallocHost(&cpu_ptr_, size_);
-    memset(cpu_ptr_, 0, size_);
+    caffe_memset(size_, 0, cpu_ptr_);
     head_ = HEAD_AT_CPU;
     own_cpu_data_ = true;
     break;
@@ -49,7 +49,7 @@ inline void SyncedMemory::to_gpu() {
   switch (head_) {
   case UNINITIALIZED:
     CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
-    CUDA_CHECK(cudaMemset(gpu_ptr_, 0, size_));
+    caffe_gpu_memset(size_, 0, gpu_ptr_);
     head_ = HEAD_AT_GPU;
     break;
   case HEAD_AT_CPU:
index b658871..b946233 100644 (file)
@@ -55,14 +55,14 @@ TEST_F(SyncedMemoryTest, TestCPUWrite) {
   SyncedMemory mem(10);
   void* cpu_data = mem.mutable_cpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU);
-  memset(cpu_data, 1, mem.size());
+  caffe_memset(mem.size(), 1, cpu_data);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((static_cast<char*>(cpu_data))[i], 1);
   }
   // do another round
   cpu_data = mem.mutable_cpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU);
-  memset(cpu_data, 2, mem.size());
+  caffe_memset(mem.size(), 2, cpu_data);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((static_cast<char*>(cpu_data))[i], 2);
   }
@@ -74,7 +74,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   SyncedMemory mem(10);
   void* cpu_data = mem.mutable_cpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU);
-  memset(cpu_data, 1, mem.size());
+  caffe_memset(mem.size(), 1, cpu_data);
   const void* gpu_data = mem.gpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
   // check if values are the same
@@ -86,7 +86,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   // do another round
   cpu_data = mem.mutable_cpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU);
-  memset(cpu_data, 2, mem.size());
+  caffe_memset(mem.size(), 2, cpu_data);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((static_cast<char*>(cpu_data))[i], 2);
   }
@@ -104,7 +104,7 @@ TEST_F(SyncedMemoryTest, TestGPUWrite) {
   SyncedMemory mem(10);
   void* gpu_data = mem.mutable_gpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_GPU);
-  CUDA_CHECK(cudaMemset(gpu_data, 1, mem.size()));
+  caffe_gpu_memset(mem.size(), 1, gpu_data);
   const void* cpu_data = mem.cpu_data();
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((static_cast<const char*>(cpu_data))[i], 1);
@@ -113,7 +113,7 @@ TEST_F(SyncedMemoryTest, TestGPUWrite) {
 
   gpu_data = mem.mutable_gpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_GPU);
-  CUDA_CHECK(cudaMemset(gpu_data, 2, mem.size()));
+  caffe_gpu_memset(mem.size(), 2, gpu_data);
   cpu_data = mem.cpu_data();
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((static_cast<const char*>(cpu_data))[i], 2);
index 974adf5..e10f019 100644 (file)
@@ -56,7 +56,7 @@ void caffe_axpy<double>(const int N, const double alpha, const double* X,
 template <typename Dtype>
 void caffe_set(const int N, const Dtype alpha, Dtype* Y) {
   if (alpha == 0) {
-    memset(Y, 0, sizeof(Dtype) * N);
+    memset(Y, 0, sizeof(Dtype) * N);  // NOLINT(caffe/alt_fn)
     return;
   }
   for (int i = 0; i < N; ++i) {
@@ -87,12 +87,13 @@ void caffe_copy(const int N, const Dtype* X, Dtype* Y) {
   if (X != Y) {
     if (Caffe::mode() == Caffe::GPU) {
 #ifndef CPU_ONLY
+      // NOLINT_NEXT_LINE(caffe/alt_fn)
       CUDA_CHECK(cudaMemcpy(Y, X, sizeof(Dtype) * N, cudaMemcpyDefault));
 #else
       NO_GPU;
 #endif
     } else {
-      memcpy(Y, X, sizeof(Dtype) * N);
+      memcpy(Y, X, sizeof(Dtype) * N);  // NOLINT(caffe/alt_fn)
     }
   }
 }
index cec051e..eacbb47 100644 (file)
@@ -78,7 +78,7 @@ void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
 
 void caffe_gpu_memcpy(const size_t N, const void* X, void* Y) {
   if (X != Y) {
-    CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
+    CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));  // NOLINT(caffe/alt_fn)
   }
 }
 
@@ -152,7 +152,7 @@ __global__ void set_kernel(const int n, const Dtype alpha, Dtype* y) {
 template <typename Dtype>
 void caffe_gpu_set(const int N, const Dtype alpha, Dtype* Y) {
   if (alpha == 0) {
-    CUDA_CHECK(cudaMemset(Y, 0, sizeof(Dtype) * N));
+    CUDA_CHECK(cudaMemset(Y, 0, sizeof(Dtype) * N));  // NOLINT(caffe/alt_fn)
     return;
   }
   // NOLINT_NEXT_LINE(whitespace/operators)