From: Yangqing Jia Date: Wed, 6 Nov 2013 22:02:33 +0000 (-0800) Subject: data layer: random skip X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5d14903df5c10145907730cfebffcfafbb91b32c;p=platform%2Fupstream%2Fcaffe.git data layer: random skip --- diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 12fd6d9..95656a4 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -113,8 +113,8 @@ void* DataLayerPrefetch(void* layer_pointer) { template void DataLayer::SetUp(const vector*>& bottom, vector*>* top) { - CHECK_EQ(bottom.size(), 0) << "Neuron Layer takes no input blobs."; - CHECK_EQ(top->size(), 2) << "Neuron Layer takes two blobs as output."; + CHECK_EQ(bottom.size(), 0) << "Data Layer takes no input blobs."; + CHECK_EQ(top->size(), 2) << "Data Layer takes two blobs as output."; // Initialize the leveldb leveldb::DB* db_temp; leveldb::Options options; @@ -127,6 +127,17 @@ void DataLayer::SetUp(const vector*>& bottom, db_.reset(db_temp); iter_.reset(db_->NewIterator(leveldb::ReadOptions())); iter_->SeekToFirst(); + // Check if we would need to randomly skip a few data points + if (this->layer_param_.rand_skip()) { + unsigned int skip = rand() % this->layer_param_.rand_skip(); + LOG(INFO) << "Skipping first " << skip << " data points."; + while (skip-- > 0) { + iter_->Next(); + if (!iter_->Valid()) { + iter_->SeekToFirst(); + } + } + } // Read a data point, and use it to initialize the top blob. Datum datum; datum.ParseFromString(iter_->value().ToString()); diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index c2d301f..cbe306d 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -79,6 +79,12 @@ message LayerParameter { repeated float blobs_lr = 51; // The weight decay that is multiplied on the global weight decay. repeated float weight_decay = 52; + + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the leveldb. + optional uint32 rand_skip = 53 [ default = 0 ]; } message LayerConnection {