replace all memset with caffe_set() / caffe_gpu_set()
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 28 Jun 2014 08:52:49 +0000 (01:52 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 4 Jul 2014 00:20:28 +0000 (17:20 -0700)
...except for `SyncedMem` since it has no type.

src/caffe/layers/conv_layer.cpp
src/caffe/layers/conv_layer.cu
src/caffe/layers/lrn_layer.cpp
src/caffe/layers/multinomial_logistic_loss_layer.cpp
src/caffe/layers/window_data_layer.cpp
src/caffe/util/im2col.cpp
src/caffe/util/im2col.cu
src/caffe/util/math_functions.cu

index 67913bf..963dc68 100644 (file)
@@ -126,11 +126,11 @@ void ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* weight = this->blobs_[0]->cpu_data();
   Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
-  memset(weight_diff, 0, sizeof(Dtype) * this->blobs_[0]->count());
+  caffe_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
   Dtype* bias_diff = NULL;
   if (bias_term_) {
     bias_diff = this->blobs_[1]->mutable_cpu_diff();
-    memset(bias_diff, 0, sizeof(Dtype) * this->blobs_[1]->count());
+    caffe_set(this->blobs_[1]->count(), Dtype(0), bias_diff);
   }
   const int weight_offset = M_ * K_;
   const int col_offset = K_ * N_;
index 71b00c9..59ec58d 100644 (file)
@@ -48,15 +48,13 @@ void ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* weight = this->blobs_[0]->gpu_data();
   Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
-  CUDA_CHECK(cudaMemset(weight_diff, 0,
-      sizeof(Dtype) * this->blobs_[0]->count()));
+  caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
   Dtype* col_data = col_buffer_.mutable_gpu_data();
   Dtype* col_diff = col_buffer_.mutable_gpu_diff();
   Dtype* bias_diff = NULL;
   if (bias_term_) {
     bias_diff = this->blobs_[1]->mutable_gpu_diff();
-    CUDA_CHECK(cudaMemset(bias_diff, 0,
-        sizeof(Dtype) * this->blobs_[1]->count()));
+    caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff);
   }
   const int weight_offset = M_ * K_;
   const int col_offset = K_ * N_;
index a86c1d4..2bda043 100644 (file)
@@ -123,7 +123,7 @@ Dtype LRNLayer<Dtype>::CrossChannelForward_cpu(
   }
   Blob<Dtype> padded_square(1, channels_ + size_ - 1, height_, width_);
   Dtype* padded_square_data = padded_square.mutable_cpu_data();
-  memset(padded_square_data, 0, sizeof(Dtype) * padded_square.count());
+  caffe_set(padded_square.count(), Dtype(0), padded_square_data);
   Dtype alpha_over_size = alpha_ / size_;
   // go through the images
   for (int n = 0; n < num_; ++n) {
@@ -201,7 +201,7 @@ void LRNLayer<Dtype>::CrossChannelBackward_cpu(
   Dtype* accum_ratio_data = accum_ratio.mutable_cpu_data();
   // We hack a little bit by using the diff() to store an additional result
   Dtype* accum_ratio_times_bottom = accum_ratio.mutable_cpu_diff();
-  memset(padded_ratio_data, 0, sizeof(Dtype) * padded_ratio.count());
+  caffe_set(padded_ratio.count(), Dtype(0), padded_ratio_data);
   Dtype cache_ratio_value = 2. * alpha_ * beta_ / size_;
 
   caffe_powx<Dtype>(scale_.count(), scale_data, -beta_, bottom_diff);
@@ -220,7 +220,7 @@ void LRNLayer<Dtype>::CrossChannelBackward_cpu(
         scale_data + block_offset,
         padded_ratio_data + padded_ratio.offset(0, inverse_pre_pad));
     // Now, compute the accumulated ratios and the bottom diff
-    memset(accum_ratio_data, 0, sizeof(Dtype) * accum_ratio.count());
+    caffe_set(accum_ratio.count(), Dtype(0), accum_ratio_data);
     for (int c = 0; c < size_ - 1; ++c) {
       caffe_axpy<Dtype>(height_ * width_, 1.,
           padded_ratio_data + padded_ratio.offset(0, c), accum_ratio_data);
index dd5cae4..8687784 100644 (file)
@@ -55,7 +55,7 @@ void MultinomialLogisticLossLayer<Dtype>::Backward_cpu(
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     int num = (*bottom)[0]->num();
     int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
-    memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count());
+    caffe_set((*bottom)[0]->count(), Dtype(0), bottom_diff);
     for (int i = 0; i < num; ++i) {
       int label = static_cast<int>(bottom_label[i]);
       Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
index fd4860f..5dbdff3 100644 (file)
@@ -59,7 +59,7 @@ void* WindowDataLayerPrefetch(void* layer_pointer) {
   bool use_square = (crop_mode == "square") ? true : false;
 
   // zero out batch
-  memset(top_data, 0, sizeof(Dtype)*layer->prefetch_data_->count());
+  caffe_set(layer->prefetch_data_->count(), Dtype(0), top_data);
 
   const int num_fg = static_cast<int>(static_cast<float>(batch_size)
       * fg_fraction);
index 037410e..ce4e188 100644 (file)
@@ -5,6 +5,7 @@
 #include <cstring>
 
 #include "caffe/util/im2col.hpp"
+#include "caffe/util/math_functions.hpp"
 
 namespace caffe {
 
@@ -45,7 +46,7 @@ template <typename Dtype>
 void col2im_cpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int ksize, const int pad,
     const int stride, Dtype* data_im) {
-  memset(data_im, 0, sizeof(Dtype) * height * width * channels);
+  caffe_set(height * width * channels, Dtype(0), data_im);
   int height_col = (height + 2 * pad - ksize) / stride + 1;
   int width_col = (width + 2 * pad - ksize) / stride + 1;
   int channels_col = channels * ksize * ksize;
index ec4465e..79faa6c 100644 (file)
@@ -106,8 +106,6 @@ template <typename Dtype>
 void col2im_gpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int ksize, const int pad,
     const int stride, Dtype* data_im) {
-  // CUDA_CHECK(cudaMemset(data_im, 0,
-  //            sizeof(Dtype) * height * width * channels));
   int height_col = (height + 2 * pad - ksize) / stride + 1;
   int width_col = (width + 2 * pad - ksize) / stride + 1;
   int num_kernels = channels * height * width;
index 63c8fac..849e53b 100644 (file)
@@ -20,27 +20,20 @@ __global__ void set_kernel(const int n, const Dtype alpha, Dtype* y) {
   }
 }
 
-template <>
-void caffe_gpu_set(const int N, const float alpha, float* Y) {
+template <typename Dtype>
+void caffe_gpu_set(const int N, const Dtype alpha, Dtype* Y) {
   if (alpha == 0) {
-    CUDA_CHECK(cudaMemset(Y, 0, sizeof(float) * N));
+    CUDA_CHECK(cudaMemset(Y, 0, sizeof(Dtype) * N));
     return;
   }
   // NOLINT_NEXT_LINE(whitespace/operators)
-  set_kernel<float><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+  set_kernel<Dtype><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
       N, alpha, Y);
 }
 
-template <>
-void caffe_gpu_set(const int N, const double alpha, double* Y) {
-  if (alpha == 0) {
-    CUDA_CHECK(cudaMemset(Y, 0, sizeof(double) * N));
-    return;
-  }
-  // NOLINT_NEXT_LINE(whitespace/operators)
-  set_kernel<double><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
-      N, alpha, Y);
-}
+template void caffe_gpu_set<int>(const int N, const int alpha, int* Y);
+template void caffe_gpu_set<float>(const int N, const float alpha, float* Y);
+template void caffe_gpu_set<double>(const int N, const double alpha, double* Y);
 
 template <typename Dtype>
 __global__ void add_scalar_kernel(const int n, const Dtype alpha, Dtype* y) {