pylint and code cleaning
authorYangqing Jia <jiayq84@gmail.com>
Mon, 23 Sep 2013 23:00:47 +0000 (16:00 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 23 Sep 2013 23:00:47 +0000 (16:00 -0700)
src/caffe/blob.cpp
src/caffe/blob.hpp
src/caffe/common.hpp
src/caffe/filler.hpp
src/caffe/layer.hpp
src/caffe/layer_factory.hpp
src/caffe/syncedmem.cpp
src/caffe/syncedmem.hpp
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.hpp

index 4bba3f1..d0e47da 100644 (file)
@@ -35,13 +35,16 @@ Blob<Dtype>::Blob(const Blob<Dtype>& source) {
   if (source.count() == 0) {
     Blob();
   } else {
-    Reshape(source.num(), source.channels(), source.height(), source.width());
+    Reshape(source.num(), source.channels(), source.height(),
+        source.width());
     // create the synced memories.
     data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
     diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
     // Copy the data.
-    memcpy(data_->mutable_cpu_data(), source.cpu_data(), count_ * sizeof(Dtype));
-    memcpy(diff_->mutable_cpu_data(), source.cpu_diff(), count_ * sizeof(Dtype));
+    memcpy(data_->mutable_cpu_data(), source.cpu_data(),
+        count_ * sizeof(Dtype));
+    memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
+        count_ * sizeof(Dtype));
   }
 }
 
@@ -72,25 +75,25 @@ const Dtype* Blob<Dtype>::gpu_diff() const {
 template <typename Dtype>
 Dtype* Blob<Dtype>::mutable_cpu_data() {
   CHECK(data_);
-  return (Dtype*)data_->mutable_cpu_data();
+  return reinterpret_cast<Dtype*>(data_->mutable_cpu_data());
 }
 
 template <typename Dtype>
 Dtype* Blob<Dtype>::mutable_gpu_data() {
   CHECK(data_);
-  return (Dtype*)data_->mutable_gpu_data();
+  return reinterpret_cast<Dtype*>(data_->mutable_gpu_data());
 }
 
 template <typename Dtype>
 Dtype* Blob<Dtype>::mutable_cpu_diff() {
   CHECK(diff_);
-  return (Dtype*)diff_->mutable_cpu_data();
+  return reinterpret_cast<Dtype*>(diff_->mutable_cpu_data());
 }
 
 template <typename Dtype>
 Dtype* Blob<Dtype>::mutable_gpu_diff() {
   CHECK(diff_);
-  return (Dtype*)diff_->mutable_gpu_data();
+  return reinterpret_cast<Dtype*>(diff_->mutable_gpu_data());
 }
 
 template <typename Dtype>
index 0b136d8..35e6d2c 100644 (file)
@@ -14,11 +14,11 @@ class Blob {
  public:
   Blob()
        : num_(0), channels_(0), height_(0), width_(0), count_(0), data_(),
-       diff_() {};
+       diff_() {}
   explicit Blob(const int num, const int channels, const int height,
     const int width);
   Blob(const Blob<Dtype>& source);
-  virtual ~Blob() {};
+  virtual ~Blob() {}
   void Reshape(const int num, const int height,
       const int width, const int channels);
   inline int num() const { return num_; }
index e3633ba..c9b9f65 100644 (file)
@@ -7,11 +7,11 @@
 #include <cublas_v2.h>
 #include <cuda.h>
 #include <curand.h>
+//cuda driver types
+#include <driver_types.h>
 #include <glog/logging.h>
 #include <mkl_vsl.h>
 
-#include "driver_types.h"
-
 #define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
 #define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
 #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
index 388f284..024f8d7 100644 (file)
@@ -8,10 +8,12 @@
 #define CAFFE_FILLER_HPP
 
 #include <mkl.h>
+#include <string>
 
 #include "caffe/common.hpp"
 #include "caffe/blob.hpp"
 #include "caffe/syncedmem.hpp"
+#include "caffe/util/math_functions.hpp"
 #include "caffe/proto/layer_param.pb.h"
 
 namespace caffe {
@@ -19,22 +21,19 @@ namespace caffe {
 template <typename Dtype>
 class Filler {
  public:
-  Filler(const FillerParameter& param) : filler_param_(param) {};
-  virtual ~Filler() {};
+  explicit Filler(const FillerParameter& param) : filler_param_(param) {}
+  virtual ~Filler() {}
   virtual void Fill(Blob<Dtype>* blob) = 0;
  protected:
   FillerParameter filler_param_;
 };  // class Filler
 
-template <typename Dtype>
-class FillerFactory {
-
-};
 
 template <typename Dtype>
 class ConstantFiller : public Filler<Dtype> {
  public:
-  ConstantFiller(const FillerParameter& param) : Filler<Dtype>(param) {};
+  explicit ConstantFiller(const FillerParameter& param)
+      : Filler<Dtype>(param) {}
   virtual void Fill(Blob<Dtype>* blob) {
     Dtype* data = blob->mutable_cpu_data();
     const int count = blob->count();
@@ -49,53 +48,28 @@ class ConstantFiller : public Filler<Dtype> {
 template <typename Dtype>
 class UniformFiller : public Filler<Dtype> {
  public:
-  UniformFiller(const FillerParameter& param) : Filler<Dtype>(param) {};
+  explicit UniformFiller(const FillerParameter& param)
+      : Filler<Dtype>(param) {}
   virtual void Fill(Blob<Dtype>* blob) {
-    void* data = (void*)(blob->mutable_cpu_data());
-    const int count = blob->count();
-    const Dtype value = this->filler_param_.value();
-    CHECK(count);
-    switch(sizeof(Dtype)) {
-    case sizeof(float):
-      VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
-          count, (float*)data, this->filler_param_.min(),
-          this->filler_param_.max()));
-      break;
-    case sizeof(double):
-      VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
-          count, (double*)data, this->filler_param_.min(),
-          this->filler_param_.max()));
-      break;
-    default:
-      CHECK(false) << "Unknown dtype.";
-    }
-  };
+    DCHECK(blob->count());
+    caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
+        Dtype(this->filler_param_.min()),
+        Dtype(this->filler_param_.max()));
+  }
 };
 
 template <typename Dtype>
 class GaussianFiller : public Filler<Dtype> {
  public:
-  GaussianFiller(const FillerParameter& param) : Filler<Dtype>(param) {};
+  explicit GaussianFiller(const FillerParameter& param)
+      : Filler<Dtype>(param) {}
   virtual void Fill(Blob<Dtype>* blob) {
-    void* data = (void*)(blob->mutable_cpu_data());
-    const int count = blob->count();
-    const Dtype value = this->filler_param_.value();
-    CHECK(count);
-    switch(sizeof(Dtype)) {
-    case sizeof(float):
-      VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
-          Caffe::vsl_stream(), count, (float*)data,
-          this->filler_param_.mean(), this->filler_param_.std()));
-      break;
-    case sizeof(double):
-      VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
-          Caffe::vsl_stream(), count, (double*)data,
-          this->filler_param_.mean(), this->filler_param_.std()));
-      break;
-    default:
-      CHECK(false) << "Unknown dtype.";
-    }
-  };
+    Dtype* data = blob->mutable_cpu_data();
+    DCHECK(blob->count());
+    caffe_vRngGaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
+        Dtype(this->filler_param_.mean()),
+        Dtype(this->filler_param_.std()));
+  }
 };
 
 // A function to get a specific filler from the specification given in
index 957b78c..130d3fb 100644 (file)
@@ -15,12 +15,12 @@ namespace caffe {
 template <typename Dtype>
 class Layer {
  public:
-   // You should not implement your own constructor. Any set up code should go
-   // to SetUp(), where the dimensions of the bottom blobs are provided to the
-   // layer.
+  // You should not implement your own constructor. Any set up code should go
+  // to SetUp(), where the dimensions of the bottom blobs are provided to the
+  // layer.
   explicit Layer(const LayerParameter& param)
-    : layer_param_(param) {};
-  virtual ~Layer() {};
+    : layer_param_(param) {}
+  virtual ~Layer() {}
   // SetUp: your function should implement this.
   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) = 0;
@@ -35,7 +35,9 @@ class Layer {
       vector<Blob<Dtype>*>* bottom);
 
   // Returns the vector of parameters.
-  vector<Blob<Dtype> >& params() { return blobs_; };
+  vector<Blob<Dtype> >& params() {
+    return blobs_;
+  }
 
  protected:
   // The protobuf that stores the layer parameters
@@ -73,7 +75,7 @@ class Layer {
 template <typename Dtype>
 inline void Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
-  switch(Caffe::mode()) {
+  switch (Caffe::mode()) {
   case Caffe::CPU:
     Forward_cpu(bottom, top);
     break;
@@ -89,7 +91,7 @@ template <typename Dtype>
 inline Dtype Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  switch(Caffe::mode()) {
+  switch (Caffe::mode()) {
   case Caffe::CPU:
     return Backward_cpu(top, propagate_down, bottom);
   case Caffe::GPU:
index 06c9df5..4909d27 100644 (file)
@@ -3,6 +3,8 @@
 #ifndef CAFFE_LAYER_FACTORY_HPP_
 #define CAFFE_LAYER_FACTORY_HPP_
 
+#include <string>
+
 #include "caffe/layer.hpp"
 #include "caffe/vision_layers.hpp"
 #include "caffe/proto/layer_param.pb.h"
index cffb297..8e6996d 100644 (file)
@@ -12,14 +12,14 @@ SyncedMemory::~SyncedMemory() {
   if (cpu_ptr_) {
     CUDA_CHECK(cudaFreeHost(cpu_ptr_));
   }
-  
+
   if (gpu_ptr_) {
     CUDA_CHECK(cudaFree(gpu_ptr_));
   }
 }
 
 inline void SyncedMemory::to_cpu() {
-  switch(head_) {
+  switch (head_) {
   case UNINITIALIZED:
     CUDA_CHECK(cudaMallocHost(&cpu_ptr_, size_));
     memset(cpu_ptr_, 0, size_);
@@ -39,7 +39,7 @@ inline void SyncedMemory::to_cpu() {
 }
 
 inline void SyncedMemory::to_gpu() {
-  switch(head_) {
+  switch (head_) {
   case UNINITIALIZED:
     CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
     CUDA_CHECK(cudaMemset(gpu_ptr_, 0, size_));
index 4c56afd..9cf3b87 100644 (file)
@@ -8,9 +8,9 @@ namespace caffe {
 class SyncedMemory {
  public:
   SyncedMemory()
-      : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED) {};
+      : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED) {}
   explicit SyncedMemory(size_t size)
-      : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED) {};
+      : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED) {}
   ~SyncedMemory();
   const void* cpu_data();
   const void* gpu_data();
index 0c33c65..5a44468 100644 (file)
@@ -157,4 +157,34 @@ template <>
 void caffe_powx<double>(const int n, const double* a, const double b,
     double* y) { vdPowx(n, a, b, y); }
 
+template <>
+void caffe_vRngUniform<float>(const int n, float* r,
+    const float a, const float b) {
+  VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
+      n, r, a, b));
+}
+
+template <>
+void caffe_vRngUniform<double>(const int n, double* r,
+    const double a, const double b) {
+  VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
+      n, r, a, b));
+}
+
+template <>
+void caffe_vRngGaussian<float>(const int n, float* r, const float a,
+    const float sigma){
+  VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
+      Caffe::vsl_stream(), n, r, a, sigma));
+}
+
+
+template <>
+void caffe_vRngGaussian<double>(const int n, double* r, const double a,
+    const double sigma){
+  VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
+      Caffe::vsl_stream(), n, r, a, sigma));
+}
+
+
 }  // namespace caffe
index 42304ba..0c03c59 100644 (file)
@@ -57,6 +57,13 @@ void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
 template <typename Dtype>
 void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
 
+template <typename Dtype>
+void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);
+
+template <typename Dtype>
+void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
+    const Dtype sigma);
+
 }  // namespace caffe