data layer: random skip
authorYangqing Jia <jiayq84@gmail.com>
Wed, 6 Nov 2013 22:02:33 +0000 (14:02 -0800)
committerYangqing Jia <jiayq84@gmail.com>
Wed, 6 Nov 2013 22:02:33 +0000 (14:02 -0800)
src/caffe/layers/data_layer.cpp
src/caffe/proto/caffe.proto

index 12fd6d9..95656a4 100644 (file)
@@ -113,8 +113,8 @@ void* DataLayerPrefetch(void* layer_pointer) {
 template <typename Dtype>
 void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 0) << "Neuron Layer takes no input blobs.";
-  CHECK_EQ(top->size(), 2) << "Neuron Layer takes two blobs as output.";
+  CHECK_EQ(bottom.size(), 0) << "Data Layer takes no input blobs.";
+  CHECK_EQ(top->size(), 2) << "Data Layer takes two blobs as output.";
   // Initialize the leveldb
   leveldb::DB* db_temp;
   leveldb::Options options;
@@ -127,6 +127,17 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   db_.reset(db_temp);
   iter_.reset(db_->NewIterator(leveldb::ReadOptions()));
   iter_->SeekToFirst();
+  // Check if we would need to randomly skip a few data points
+  if (this->layer_param_.rand_skip()) {
+    unsigned int skip = rand() % this->layer_param_.rand_skip();
+    LOG(INFO) << "Skipping first " << skip << " data points.";
+    while (skip-- > 0) {
+      iter_->Next();
+      if (!iter_->Valid()) {
+        iter_->SeekToFirst();
+      }
+    }
+  }
   // Read a data point, and use it to initialize the top blob.
   Datum datum;
   datum.ParseFromString(iter_->value().ToString());
index c2d301f..cbe306d 100644 (file)
@@ -79,6 +79,12 @@ message LayerParameter {
   repeated float blobs_lr = 51;
   // The weight decay that is multiplied on the global weight decay.
   repeated float weight_decay = 52;
+
+  // The rand_skip variable is for the data layer to skip a few data points
+  // to avoid all asynchronous sgd clients to start at the same point. The skip
+  // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
+  // be larger than the number of keys in the leveldb.
+  optional uint32 rand_skip = 53 [ default = 0 ];
 }
 
 message LayerConnection {