make RNG function outputs the last argument per Google C++ style guidelines
authorJeff Donahue <jeff.donahue@gmail.com>
Tue, 8 Apr 2014 20:18:29 +0000 (13:18 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Wed, 9 Apr 2014 03:17:17 +0000 (20:17 -0700)
include/caffe/filler.hpp
include/caffe/util/math_functions.hpp
src/caffe/layers/dropout_layer.cpp
src/caffe/test/test_common.cpp
src/caffe/test/test_random_number_generator.cpp
src/caffe/util/math_functions.cpp

index 256a03b..50a397e 100644 (file)
@@ -51,9 +51,8 @@ class UniformFiller : public Filler<Dtype> {
       : Filler<Dtype>(param) {}
   virtual void Fill(Blob<Dtype>* blob) {
     CHECK(blob->count());
-    caffe_rng_uniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
-        Dtype(this->filler_param_.min()),
-        Dtype(this->filler_param_.max()));
+    caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
+        Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
   }
 };
 
@@ -65,9 +64,8 @@ class GaussianFiller : public Filler<Dtype> {
   virtual void Fill(Blob<Dtype>* blob) {
     Dtype* data = blob->mutable_cpu_data();
     CHECK(blob->count());
-    caffe_rng_gaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
-        Dtype(this->filler_param_.mean()),
-        Dtype(this->filler_param_.std()));
+    caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
+        Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
   }
 };
 
@@ -79,7 +77,7 @@ class PositiveUnitballFiller : public Filler<Dtype> {
   virtual void Fill(Blob<Dtype>* blob) {
     Dtype* data = blob->mutable_cpu_data();
     DCHECK(blob->count());
-    caffe_rng_uniform<Dtype>(blob->count(), blob->mutable_cpu_data(), 0, 1);
+    caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
     // We expect the filler to not be called very frequently, so we will
     // just use a simple implementation
     int dim = blob->count() / blob->num();
@@ -113,8 +111,8 @@ class XavierFiller : public Filler<Dtype> {
     CHECK(blob->count());
     int fan_in = blob->count() / blob->num();
     Dtype scale = sqrt(Dtype(3) / fan_in);
-    caffe_rng_uniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
-        -scale, scale);
+    caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
+        blob->mutable_cpu_data());
   }
 };
 
index 23aa265..7129cf9 100644 (file)
@@ -109,14 +109,14 @@ template <typename Dtype>
 Dtype caffe_nextafter(const Dtype b);
 
 template <typename Dtype>
-void caffe_rng_uniform(const int n, Dtype* r, const Dtype a, const Dtype b);
+void caffe_rng_uniform(const int n, const Dtype a, const Dtype b, Dtype* r);
 
 template <typename Dtype>
-void caffe_rng_gaussian(const int n, Dtype* r, const Dtype a,
-    const Dtype sigma);
+void caffe_rng_gaussian(const int n, const Dtype mu, const Dtype sigma,
+                        Dtype* r);
 
 template <typename Dtype>
-void caffe_rng_bernoulli(const int n, int* r, const Dtype p);
+void caffe_rng_bernoulli(const int n, const Dtype p, int* r);
 
 template <typename Dtype>
 void caffe_exp(const int n, const Dtype* a, Dtype* y);
index a9dd842..e0068fe 100644 (file)
@@ -32,7 +32,7 @@ Dtype DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const int count = bottom[0]->count();
   if (Caffe::phase() == Caffe::TRAIN) {
     // Create random numbers
-    caffe_rng_bernoulli(count, mask, 1. - threshold_);
+    caffe_rng_bernoulli(count, 1. - threshold_, mask);
     for (int i = 0; i < count; ++i) {
       top_data[i] = bottom_data[i] * mask[i] * scale_;
     }
index a043db1..f236d12 100644 (file)
@@ -36,16 +36,14 @@ TEST_F(CommonTest, TestRandSeedCPU) {
   SyncedMemory data_a(10 * sizeof(int));
   SyncedMemory data_b(10 * sizeof(int));
   Caffe::set_random_seed(1701);
-  caffe_rng_bernoulli(10,
-      reinterpret_cast<int*>(data_a.mutable_cpu_data()), 0.5);
+  caffe_rng_bernoulli(10, 0.5, static_cast<int*>(data_a.mutable_cpu_data()));
 
   Caffe::set_random_seed(1701);
-  caffe_rng_bernoulli(10,
-      reinterpret_cast<int*>(data_b.mutable_cpu_data()), 0.5);
+  caffe_rng_bernoulli(10, 0.5, static_cast<int*>(data_b.mutable_cpu_data()));
 
   for (int i = 0; i < 10; ++i) {
-    EXPECT_EQ(((const int*)(data_a.cpu_data()))[i],
-        ((const int*)(data_b.cpu_data()))[i]);
+    EXPECT_EQ(static_cast<const int*>(data_a.cpu_data())[i],
+        static_cast<const int*>(data_b.cpu_data())[i]);
   }
 }
 
index 664f82f..d068d03 100644 (file)
@@ -60,7 +60,7 @@ class RandomNumberGeneratorTest : public ::testing::Test {
 
   void RngGaussianTest(const Dtype mu, const Dtype sigma, void* cpu_data) {
     Dtype* rng_data = static_cast<Dtype*>(cpu_data);
-    caffe_rng_gaussian(sample_size_, rng_data, mu, sigma);
+    caffe_rng_gaussian(sample_size_, mu, sigma, rng_data);
     const Dtype true_mean = mu;
     const Dtype true_std = sigma;
     // Check that sample mean roughly matches true mean.
@@ -90,7 +90,7 @@ class RandomNumberGeneratorTest : public ::testing::Test {
   void RngUniformTest(const Dtype lower, const Dtype upper, void* cpu_data) {
     CHECK_GE(upper, lower);
     Dtype* rng_data = static_cast<Dtype*>(cpu_data);
-    caffe_rng_uniform(sample_size_, rng_data, lower, upper);
+    caffe_rng_uniform(sample_size_, lower, upper, rng_data);
     const Dtype true_mean = (lower + upper) / 2;
     const Dtype true_std = (upper - lower) / sqrt(12);
     // Check that sample mean roughly matches true mean.
@@ -128,7 +128,7 @@ class RandomNumberGeneratorTest : public ::testing::Test {
 
   void RngBernoulliTest(const Dtype p, void* cpu_data) {
     int* rng_data = static_cast<int*>(cpu_data);
-    caffe_rng_bernoulli(sample_size_, rng_data, p);
+    caffe_rng_bernoulli(sample_size_, p, rng_data);
     const Dtype true_mean = p;
     const Dtype true_std = sqrt(p * (1 - p));
     const Dtype bound = this->mean_bound(true_std);
index a524ed3..c26675f 100644 (file)
@@ -315,8 +315,7 @@ template
 double caffe_nextafter(const double b);
 
 template <typename Dtype>
-void caffe_rng_uniform(const int n, Dtype* r,
-    const Dtype a, const Dtype b) {
+void caffe_rng_uniform(const int n, const Dtype a, const Dtype b, Dtype* r) {
   CHECK_GE(n, 0);
   CHECK(r);
   CHECK_LE(a, b);
@@ -330,15 +329,16 @@ void caffe_rng_uniform(const int n, Dtype* r,
 }
 
 template
-void caffe_rng_uniform<float>(const int n, float* r,
-                                       const float a, const float b);
+void caffe_rng_uniform<float>(const int n, const float a, const float b,
+                              float* r);
+
 template
-void caffe_rng_uniform<double>(const int n, double* r,
-                                       const double a, const double b);
+void caffe_rng_uniform<double>(const int n, const double a, const double b,
+                               double* r);
 
 template <typename Dtype>
-void caffe_rng_gaussian(const int n, Dtype* r, const Dtype a,
-    const Dtype sigma) {
+void caffe_rng_gaussian(const int n, const Dtype a,
+                        const Dtype sigma, Dtype* r) {
   CHECK_GE(n, 0);
   CHECK(r);
   CHECK_GT(sigma, 0);
@@ -352,15 +352,15 @@ void caffe_rng_gaussian(const int n, Dtype* r, const Dtype a,
 }
 
 template
-void caffe_rng_gaussian<float>(const int n, float* r, const float a,
-    const float sigma);
+void caffe_rng_gaussian<float>(const int n, const float mu,
+                               const float sigma, float* r);
 
 template
-void caffe_rng_gaussian<double>(const int n, double* r, const double a,
-    const double sigma);
+void caffe_rng_gaussian<double>(const int n, const double mu,
+                                const double sigma, double* r);
 
 template <typename Dtype>
-void caffe_rng_bernoulli(const int n, int* r, const Dtype p) {
+void caffe_rng_bernoulli(const int n, const Dtype p, int* r) {
   CHECK_GE(n, 0);
   CHECK(r);
   CHECK_GE(p, 0);
@@ -375,10 +375,10 @@ void caffe_rng_bernoulli(const int n, int* r, const Dtype p) {
 }
 
 template
-void caffe_rng_bernoulli<double>(const int n, int* r, const double p);
+void caffe_rng_bernoulli<double>(const int n, const double p, int* r);
 
 template
-void caffe_rng_bernoulli<float>(const int n, int* r, const float p);
+void caffe_rng_bernoulli<float>(const int n, const float p, int* r);
 
 template <>
 float caffe_cpu_dot<float>(const int n, const float* x, const float* y) {