From 078e0bf713bcc4f9178f6694c66b368302448484 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Tue, 8 Apr 2014 11:57:25 -0700 Subject: [PATCH] make RNG function names more similar to other caffe math function names --- include/caffe/filler.hpp | 8 ++++---- include/caffe/util/math_functions.hpp | 6 +++--- src/caffe/layers/dropout_layer.cpp | 2 +- src/caffe/test/test_common.cpp | 4 ++-- src/caffe/test/test_random_number_generator.cpp | 6 +++--- src/caffe/util/math_functions.cpp | 18 +++++++++--------- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp index ba473e1..256a03b 100644 --- a/include/caffe/filler.hpp +++ b/include/caffe/filler.hpp @@ -51,7 +51,7 @@ class UniformFiller : public Filler { : Filler(param) {} virtual void Fill(Blob* blob) { CHECK(blob->count()); - caffe_vRngUniform(blob->count(), blob->mutable_cpu_data(), + caffe_rng_uniform(blob->count(), blob->mutable_cpu_data(), Dtype(this->filler_param_.min()), Dtype(this->filler_param_.max())); } @@ -65,7 +65,7 @@ class GaussianFiller : public Filler { virtual void Fill(Blob* blob) { Dtype* data = blob->mutable_cpu_data(); CHECK(blob->count()); - caffe_vRngGaussian(blob->count(), blob->mutable_cpu_data(), + caffe_rng_gaussian(blob->count(), blob->mutable_cpu_data(), Dtype(this->filler_param_.mean()), Dtype(this->filler_param_.std())); } @@ -79,7 +79,7 @@ class PositiveUnitballFiller : public Filler { virtual void Fill(Blob* blob) { Dtype* data = blob->mutable_cpu_data(); DCHECK(blob->count()); - caffe_vRngUniform(blob->count(), blob->mutable_cpu_data(), 0, 1); + caffe_rng_uniform(blob->count(), blob->mutable_cpu_data(), 0, 1); // We expect the filler to not be called very frequently, so we will // just use a simple implementation int dim = blob->count() / blob->num(); @@ -113,7 +113,7 @@ class XavierFiller : public Filler { CHECK(blob->count()); int fan_in = blob->count() / blob->num(); Dtype scale = sqrt(Dtype(3) / fan_in); - caffe_vRngUniform(blob->count(), blob->mutable_cpu_data(), + caffe_rng_uniform(blob->count(), blob->mutable_cpu_data(), -scale, scale); } }; diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index 77e3234..23aa265 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -109,14 +109,14 @@ template Dtype caffe_nextafter(const Dtype b); template -void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b); +void caffe_rng_uniform(const int n, Dtype* r, const Dtype a, const Dtype b); template -void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a, +void caffe_rng_gaussian(const int n, Dtype* r, const Dtype a, const Dtype sigma); template -void caffe_vRngBernoulli(const int n, int* r, const Dtype p); +void caffe_rng_bernoulli(const int n, int* r, const Dtype p); template void caffe_exp(const int n, const Dtype* a, Dtype* y); diff --git a/src/caffe/layers/dropout_layer.cpp b/src/caffe/layers/dropout_layer.cpp index a57999c..a9dd842 100644 --- a/src/caffe/layers/dropout_layer.cpp +++ b/src/caffe/layers/dropout_layer.cpp @@ -32,7 +32,7 @@ Dtype DropoutLayer::Forward_cpu(const vector*>& bottom, const int count = bottom[0]->count(); if (Caffe::phase() == Caffe::TRAIN) { // Create random numbers - caffe_vRngBernoulli(count, mask, 1. - threshold_); + caffe_rng_bernoulli(count, mask, 1. - threshold_); for (int i = 0; i < count; ++i) { top_data[i] = bottom_data[i] * mask[i] * scale_; } diff --git a/src/caffe/test/test_common.cpp b/src/caffe/test/test_common.cpp index 7839c37..a043db1 100644 --- a/src/caffe/test/test_common.cpp +++ b/src/caffe/test/test_common.cpp @@ -36,11 +36,11 @@ TEST_F(CommonTest, TestRandSeedCPU) { SyncedMemory data_a(10 * sizeof(int)); SyncedMemory data_b(10 * sizeof(int)); Caffe::set_random_seed(1701); - caffe_vRngBernoulli(10, + caffe_rng_bernoulli(10, reinterpret_cast(data_a.mutable_cpu_data()), 0.5); Caffe::set_random_seed(1701); - caffe_vRngBernoulli(10, + caffe_rng_bernoulli(10, reinterpret_cast(data_b.mutable_cpu_data()), 0.5); for (int i = 0; i < 10; ++i) { diff --git a/src/caffe/test/test_random_number_generator.cpp b/src/caffe/test/test_random_number_generator.cpp index c1425f2..664f82f 100644 --- a/src/caffe/test/test_random_number_generator.cpp +++ b/src/caffe/test/test_random_number_generator.cpp @@ -60,7 +60,7 @@ class RandomNumberGeneratorTest : public ::testing::Test { void RngGaussianTest(const Dtype mu, const Dtype sigma, void* cpu_data) { Dtype* rng_data = static_cast(cpu_data); - caffe_vRngGaussian(sample_size_, rng_data, mu, sigma); + caffe_rng_gaussian(sample_size_, rng_data, mu, sigma); const Dtype true_mean = mu; const Dtype true_std = sigma; // Check that sample mean roughly matches true mean. @@ -90,7 +90,7 @@ class RandomNumberGeneratorTest : public ::testing::Test { void RngUniformTest(const Dtype lower, const Dtype upper, void* cpu_data) { CHECK_GE(upper, lower); Dtype* rng_data = static_cast(cpu_data); - caffe_vRngUniform(sample_size_, rng_data, lower, upper); + caffe_rng_uniform(sample_size_, rng_data, lower, upper); const Dtype true_mean = (lower + upper) / 2; const Dtype true_std = (upper - lower) / sqrt(12); // Check that sample mean roughly matches true mean. @@ -128,7 +128,7 @@ class RandomNumberGeneratorTest : public ::testing::Test { void RngBernoulliTest(const Dtype p, void* cpu_data) { int* rng_data = static_cast(cpu_data); - caffe_vRngBernoulli(sample_size_, rng_data, p); + caffe_rng_bernoulli(sample_size_, rng_data, p); const Dtype true_mean = p; const Dtype true_std = sqrt(p * (1 - p)); const Dtype bound = this->mean_bound(true_std); diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 0791a86..a524ed3 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -315,7 +315,7 @@ template double caffe_nextafter(const double b); template -void caffe_vRngUniform(const int n, Dtype* r, +void caffe_rng_uniform(const int n, Dtype* r, const Dtype a, const Dtype b) { CHECK_GE(n, 0); CHECK(r); @@ -330,14 +330,14 @@ void caffe_vRngUniform(const int n, Dtype* r, } template -void caffe_vRngUniform(const int n, float* r, +void caffe_rng_uniform(const int n, float* r, const float a, const float b); template -void caffe_vRngUniform(const int n, double* r, +void caffe_rng_uniform(const int n, double* r, const double a, const double b); template -void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a, +void caffe_rng_gaussian(const int n, Dtype* r, const Dtype a, const Dtype sigma) { CHECK_GE(n, 0); CHECK(r); @@ -352,15 +352,15 @@ void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a, } template -void caffe_vRngGaussian(const int n, float* r, const float a, +void caffe_rng_gaussian(const int n, float* r, const float a, const float sigma); template -void caffe_vRngGaussian(const int n, double* r, const double a, +void caffe_rng_gaussian(const int n, double* r, const double a, const double sigma); template -void caffe_vRngBernoulli(const int n, int* r, const Dtype p) { +void caffe_rng_bernoulli(const int n, int* r, const Dtype p) { CHECK_GE(n, 0); CHECK(r); CHECK_GE(p, 0); @@ -375,10 +375,10 @@ void caffe_vRngBernoulli(const int n, int* r, const Dtype p) { } template -void caffe_vRngBernoulli(const int n, int* r, const double p); +void caffe_rng_bernoulli(const int n, int* r, const double p); template -void caffe_vRngBernoulli(const int n, int* r, const float p); +void caffe_rng_bernoulli(const int n, int* r, const float p); template <> float caffe_cpu_dot(const int n, const float* x, const float* y) { -- 2.7.4