more cpplint
authorYangqing Jia <jiayq84@gmail.com>
Mon, 23 Sep 2013 23:16:35 +0000 (16:16 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 23 Sep 2013 23:16:35 +0000 (16:16 -0700)
src/caffe/layers/conv_layer.cpp
src/caffe/layers/im2col_layer.cpp
src/caffe/layers/inner_product_layer.cpp
src/caffe/layers/lrn_layer.cpp
src/caffe/layers/neuron_layer.cpp
src/caffe/layers/pooling_layer.cpp

index f5e0a5b..849e106 100644 (file)
@@ -1,5 +1,7 @@
 // Copyright 2013 Yangqing Jia
 
+#include <vector>
+
 #include "caffe/layer.hpp"
 #include "caffe/vision_layers.hpp"
 #include "caffe/util/im2col.hpp"
@@ -31,7 +33,7 @@ void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_EQ(NUM_OUTPUT_ % GROUP_, 0)
       << "Number of output should be multiples of group.";
   biasterm_ = this->layer_param_.biasterm();
-  // Figure out the dimensions for individual gemms. 
+  // Figure out the dimensions for individual gemms.
   M_ = NUM_OUTPUT_ / GROUP_;
   K_ = CHANNELS_ * KSIZE_ * KSIZE_ / GROUP_;
   N_ = height_out * width_out;
@@ -54,7 +56,8 @@ void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
         GetFiller<Dtype>(this->layer_param_.bias_filler()));
     bias_filler->Fill(&this->blobs_[1]);
     bias_multiplier_.reset(new SyncedMemory(N_ * sizeof(Dtype)));
-    Dtype* bias_multiplier_data = (Dtype*)bias_multiplier_->mutable_cpu_data();
+    Dtype* bias_multiplier_data =
+        reinterpret_cast<Dtype*>(bias_multiplier_->mutable_cpu_data());
     for (int i = 0; i < N_; ++i) {
         bias_multiplier_data[i] = 1.;
     }
@@ -86,8 +89,8 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     if (biasterm_) {
       caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
           N_, 1, (Dtype)1., this->blobs_[1].cpu_data(),
-          (Dtype*)bias_multiplier_->cpu_data(), (Dtype)1.,
-          top_data + (*top)[0]->offset(n));
+          reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()),
+          (Dtype)1., top_data + (*top)[0]->offset(n));
     }
   }
 }
@@ -116,8 +119,8 @@ void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     if (biasterm_) {
       caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
           N_, 1, (Dtype)1., this->blobs_[1].gpu_data(),
-          (Dtype*)bias_multiplier_->gpu_data(), (Dtype)1.,
-          top_data + (*top)[0]->offset(n));
+          reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
+          (Dtype)1., top_data + (*top)[0]->offset(n));
     }
   }
 }
@@ -140,8 +143,9 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     memset(bias_diff, 0., sizeof(Dtype) * this->blobs_[1].count());
     for (int n = 0; n < NUM_; ++n) {
       caffe_cpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
-        1., top_diff + top[0]->offset(n),
-        (Dtype*)bias_multiplier_->cpu_data(), 1., bias_diff);
+          1., top_diff + top[0]->offset(n),
+          reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()), 1.,
+          bias_diff);
     }
   }
 
@@ -196,8 +200,9 @@ Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         sizeof(Dtype) * this->blobs_[1].count()));
     for (int n = 0; n < NUM_; ++n) {
       caffe_gpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
-        1., top_diff + top[0]->offset(n),
-        (Dtype*)bias_multiplier_->gpu_data(), 1., bias_diff);
+          1., top_diff + top[0]->offset(n),
+          reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
+          1., bias_diff);
     }
   }
 
index 30e80bf..976c844 100644 (file)
@@ -1,5 +1,7 @@
 // Copyright 2013 Yangqing Jia
 
+#include <vector>
+
 #include "caffe/layer.hpp"
 #include "caffe/util/im2col.hpp"
 #include "caffe/vision_layers.hpp"
index 9736faf..b39b568 100644 (file)
@@ -1,5 +1,7 @@
 // Copyright 2013 Yangqing Jia
 
+#include <vector>
+
 #include <mkl.h>
 #include <cublas_v2.h>
 
@@ -59,8 +61,8 @@ void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       bottom_data, weight, (Dtype)0., top_data);
   if (biasterm_) {
     caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
-        (Dtype*)bias_multiplier_->cpu_data(), this->blobs_[1].cpu_data(),
-        (Dtype)1., top_data);
+        reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()),
+        this->blobs_[1].cpu_data(), (Dtype)1., top_data);
   }
 }
 
@@ -76,7 +78,7 @@ Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   if (biasterm_) {
     // Gradient with respect to bias
     caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
-        (Dtype*)bias_multiplier_->cpu_data(), (Dtype)0.,
+        reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()), (Dtype)0.,
         this->blobs_[1].mutable_cpu_diff());
   }
   if (propagate_down) {
@@ -98,8 +100,8 @@ void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       bottom_data, weight, (Dtype)0., top_data);
   if (biasterm_) {
     caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
-        (Dtype*)bias_multiplier_->gpu_data(), this->blobs_[1].gpu_data(),
-        (Dtype)1., top_data);
+        reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
+        this->blobs_[1].gpu_data(), (Dtype)1., top_data);
   }
 }
 
@@ -115,8 +117,8 @@ Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   if (biasterm_) {
     // Gradient with respect to bias
     caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
-        (Dtype*)bias_multiplier_->gpu_data(), (Dtype)0.,
-        this->blobs_[1].mutable_gpu_diff());
+        reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
+        (Dtype)0., this->blobs_[1].mutable_gpu_diff());
   }
   if (propagate_down) {
     // Gradient with respect to bottom data
index 1c85a67..337b77b 100644 (file)
@@ -1,5 +1,7 @@
 // Copyright 2013 Yangqing Jia
 
+#include <vector>
+
 #include "caffe/layer.hpp"
 #include "caffe/vision_layers.hpp"
 #include "caffe/util/math_functions.hpp"
@@ -11,7 +13,7 @@ void LRNLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 1) <<
       "Local Response Normalization Layer takes a single blob as input.";
-  CHECK_EQ(top->size(), 1) << 
+  CHECK_EQ(top->size(), 1) <<
       "Local Response Normalization Layer takes a single blob as output.";
   num_ = bottom[0]->num();
   channels_ = bottom[0]->channels();
@@ -119,7 +121,7 @@ Dtype LRNLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
           bottom_data + top[0]->offset(n, c),
           accum_ratio_data, accum_ratio_times_bottom);
       caffe_axpy<Dtype>(height_ * width_, -cache_ratio_value,
-          accum_ratio_times_bottom, bottom_diff + top[0]->offset(n,c));
+          accum_ratio_times_bottom, bottom_diff + top[0]->offset(n, c));
       caffe_axpy<Dtype>(height_ * width_, -1.,
           padded_ratio_data + padded_ratio.offset(0, c), accum_ratio_data);
     }
index 1acbfae..fcf6ff5 100644 (file)
@@ -1,5 +1,7 @@
 // Copyright 2013 Yangqing Jia
 
+#include <vector>
+
 #include "caffe/layer.hpp"
 #include "caffe/vision_layers.hpp"
 
index 6806268..e8bc9b7 100644 (file)
@@ -2,6 +2,8 @@
 
 #include <algorithm>
 #include <cfloat>
+#include <vector>
+
 #include "caffe/layer.hpp"
 #include "caffe/vision_layers.hpp"
 #include "caffe/util/math_functions.hpp"
@@ -23,14 +25,16 @@ void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   CHANNELS_ = bottom[0]->channels();
   HEIGHT_ = bottom[0]->height();
   WIDTH_ = bottom[0]->width();
-  POOLED_HEIGHT_ = int(ceil(float(HEIGHT_ - KSIZE_) / STRIDE_)) + 1;
-  POOLED_WIDTH_ = int(ceil(float(WIDTH_ - KSIZE_) / STRIDE_)) + 1;
+  POOLED_HEIGHT_ = static_cast<int>(
+      ceil(static_cast<float>(HEIGHT_ - KSIZE_) / STRIDE_)) + 1;
+  POOLED_WIDTH_ = static_cast<int>(
+      ceil(static_cast<float>(WIDTH_ - KSIZE_) / STRIDE_)) + 1;
   (*top)[0]->Reshape(bottom[0]->num(), CHANNELS_, POOLED_HEIGHT_,
       POOLED_WIDTH_);
 };
 
-
-// TODO: Is there a faster way to do pooling in the channel-first case?
+// TODO(Yangqing): Is there a faster way to do pooling in the channel-first
+// case?
 template <typename Dtype>
 void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
@@ -96,7 +100,7 @@ void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       }
     }
     // Our implementation simply divides the pooled values by KSIZE^2,
-    // regardless of the actual pooling region. This would allow one to not 
+    // regardless of the actual pooling region. This would allow one to not
     // trust too much on the border pooling regions, but I am not sure what
     // benefit / harm it would bring to the actual code.
     caffe_scal<Dtype>(top_count, Dtype(1.) / KSIZE_ / KSIZE_,
@@ -176,7 +180,7 @@ Dtype PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       }
     }
     // Our implementation simply divides the pooled values by KSIZE^2,
-    // regardless of the actual pooling region. This would allow one to not 
+    // regardless of the actual pooling region. This would allow one to not
     // trust too much on the border pooling regions, but I am not sure what
     // benefit / harm it would bring to the actual code.
     caffe_scal<Dtype>((*bottom)[0]->count(), Dtype(1.) / KSIZE_ / KSIZE_,