Fix and improve multiple places about random number generation and
authornetheril96 <netheril96@gmail.com>
Mon, 4 Aug 2014 04:28:21 +0000 (12:28 +0800)
committernetheril96 <netheril96@gmail.com>
Mon, 4 Aug 2014 04:41:58 +0000 (12:41 +0800)
shuffling

include/caffe/util/rng.hpp
src/caffe/common.cpp
src/caffe/layers/image_data_layer.cpp
tools/convert_imageset.cpp

index 5909d17..2b79ef1 100644 (file)
@@ -4,16 +4,40 @@
 #define CAFFE_RNG_CPP_HPP_
 
 #include <boost/random/mersenne_twister.hpp>
+#include <boost/random/uniform_int.hpp>
+#include <iterator>
+#include <algorithm>
 #include "caffe/common.hpp"
 
 namespace caffe {
 
-  typedef boost::mt19937 rng_t;
+typedef boost::mt19937 rng_t;
 
-  inline rng_t* caffe_rng() {
-    return static_cast<caffe::rng_t*>(Caffe::rng_stream().generator());
+inline rng_t* caffe_rng() {
+  return static_cast<caffe::rng_t*>(Caffe::rng_stream().generator());
+}
+
+// Fisher–Yates algorithm
+template <class RandomAccessIterator, class RandomGenerator>
+inline void shuffle(RandomAccessIterator begin, RandomAccessIterator end,
+                    RandomGenerator* gen) {
+  typedef typename std::iterator_traits<RandomAccessIterator>::difference_type
+      difference_type;
+  typedef typename boost::uniform_int<difference_type> dist_type;
+
+  difference_type length = std::distance(begin, end);
+  if (length <= 0) return;
+
+  for (difference_type i = length - 1; i > 0; --i) {
+    dist_type dist(0, i);
+    std::iter_swap(begin + i, begin + dist(*gen));
   }
+}
 
+template <class RandomAccessIterator>
+inline void shuffle(RandomAccessIterator begin, RandomAccessIterator end) {
+  shuffle(begin, end, caffe_rng());
+}
 }  // namespace caffe
 
 #endif  // CAFFE_RNG_HPP_
index 1810bb7..019370b 100644 (file)
@@ -26,6 +26,17 @@ shared_ptr<Caffe> Caffe::singleton_;
 // random seeding
 int64_t cluster_seedgen(void) {
   int64_t s, seed, pid;
+  FILE* f = fopen("/dev/urandom", "rb");
+  if (f && fread(&seed, 1, sizeof(seed), f) == sizeof(seed)) {
+    fclose(f);
+    return seed;
+  }
+
+  LOG(INFO) << "System entropy source not available, "
+              "using fallback algorithm to generate seed instead.";
+  if (f)
+    fclose(f);
+
   pid = getpid();
   s = time(NULL);
   seed = abs(((s * 181) * ((pid - 83) * 359)) % 104729);
@@ -75,7 +86,7 @@ Caffe::RNG::RNG() : generator_(new Generator()) { }
 Caffe::RNG::RNG(unsigned int seed) : generator_(new Generator(seed)) { }
 
 Caffe::RNG& Caffe::RNG::operator=(const RNG& other) {
-  generator_.reset(other.generator_.get());
+  generator_ = other.generator_;
   return *this;
 }
 
index efd5938..a0f91af 100644 (file)
@@ -238,14 +238,9 @@ void ImageDataLayer<Dtype>::CreatePrefetchThread() {
 
 template <typename Dtype>
 void ImageDataLayer<Dtype>::ShuffleImages() {
-  const int num_images = lines_.size();
-  for (int i = 0; i < num_images; ++i) {
-    const int max_rand_index = num_images - i;
-    const int rand_index = PrefetchRand() % max_rand_index;
-    pair<string, int> item = lines_[rand_index];
-    lines_.erase(lines_.begin() + rand_index);
-    lines_.push_back(item);
-  }
+  caffe::rng_t* prefetch_rng =
+      static_cast<caffe::rng_t*>(prefetch_rng_->generator());
+  shuffle(lines_.begin(), lines_.end(), prefetch_rng);
 }
 
 
index 28de653..21ff0d3 100644 (file)
@@ -28,6 +28,7 @@
 
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
+#include "caffe/util/rng.hpp"
 
 using namespace caffe;  // NOLINT(build/namespaces)
 using std::pair;
@@ -60,7 +61,7 @@ int main(int argc, char** argv) {
   if (argc >= (arg_offset+5) && argv[arg_offset+4][0] == '1') {
     // randomly shuffle data
     LOG(INFO) << "Shuffling data";
-    std::random_shuffle(lines.begin(), lines.end());
+    shuffle(lines.begin(), lines.end());
   }
   LOG(INFO) << "A total of " << lines.size() << " images.";