#include <stdint.h>
#include <leveldb/db.h>
+#include <pthread.h>
+#include <cstdlib>  // rand()
+#include <cstring>  // memcpy
#include <string>
#include <vector>
namespace caffe {
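+// Body of the prefetch thread: runs concurrently with network computation
+// and fills the layer's prefetch buffers with the next batch, so Forward
+// only pays for a memcpy instead of a LevelDB read and decode.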
template <typename Dtype>
+void* DataLayerPrefetch(void* layer_pointer) {
+  DataLayer<Dtype>* layer = reinterpret_cast<DataLayer<Dtype>*>(layer_pointer);
+  Datum datum;
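+  // Write directly into the layer's prefetch buffers; Forward_cpu and
+  // Forward_gpu copy them into the top blobs after joining this thread.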
+  Dtype* top_data = layer->prefetch_data_->mutable_cpu_data();
+  Dtype* top_label = layer->prefetch_label_->mutable_cpu_data();
+  const Dtype scale = layer->layer_param_.scale();
+  const Dtype subtraction = layer->layer_param_.subtraction();
+  const int batchsize = layer->layer_param_.batchsize();
+  const int cropsize = layer->layer_param_.cropsize();
+  for (int itemid = 0; itemid < batchsize; ++itemid) {
+    // Get a blob from the current database entry.
+    datum.ParseFromString(layer->iter_->value().ToString());
+    const string& data = datum.data();
+    if (cropsize) {
+      CHECK(data.size()) << "Image cropping only supports uint8 data";
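+      // Pick a random crop offset inside the image. The raw bytes are laid
+      // out channel by channel, row-major within each channel.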
+      int h_offset = rand() % (layer->datum_height_ - cropsize);
+      int w_offset = rand() % (layer->datum_width_ - cropsize);
+      for (int c = 0; c < layer->datum_channels_; ++c) {
+        for (int h = 0; h < cropsize; ++h) {
+          for (int w = 0; w < cropsize; ++w) {
+            top_data[((itemid * layer->datum_channels_ + c) * cropsize + h)
+                * cropsize + w] =
+                static_cast<Dtype>((uint8_t)data[
+                    (c * layer->datum_height_ + h + h_offset)
+                    * layer->datum_width_ + w + w_offset]
+                ) * scale - subtraction;
+          }
+        }
+      }
+    } else {
+      // We prefer uint8 data(), and fall back to float_data() if it is empty.
+      if (data.size()) {
+        for (int j = 0; j < layer->datum_size_; ++j) {
+          top_data[itemid * layer->datum_size_ + j] =
+              (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
+        }
+      } else {
+        for (int j = 0; j < layer->datum_size_; ++j) {
+          top_data[itemid * layer->datum_size_ + j] =
+              (datum.float_data(j) * scale) - subtraction;
+        }
+      }
+    }
+    top_label[itemid] = datum.label();
+    // Advance to the next entry in the database.
+    layer->iter_->Next();
+    if (!layer->iter_->Valid()) {
+      // We have reached the end. Restart from the first entry.
+      LOG(INFO) << "Restarting data read from start.";
+      layer->iter_->SeekToFirst();
+    }
+  }
+  // The function is declared to return void*; return NULL since the caller
+  // joins with a NULL result pointer and ignores the value.
+  return NULL;
+}
+
+
+template <typename Dtype>
void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  CHECK_EQ(bottom.size(), 0) << "Data Layer takes no input blobs.";
  if (cropsize > 0) {
    (*top)[0]->Reshape(
        this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize);
+    prefetch_data_.reset(new Blob<Dtype>(
+        this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize));
  } else {
    (*top)[0]->Reshape(
        this->layer_param_.batchsize(), datum.channels(), datum.height(),
        datum.width());
+    prefetch_data_.reset(new Blob<Dtype>(
+        this->layer_param_.batchsize(), datum.channels(), datum.height(),
+        datum.width()));
  }
  LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
      << (*top)[0]->channels() << "," << (*top)[0]->height() << ","
      << (*top)[0]->width();
  // label
  (*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1);
+  prefetch_label_.reset(
+      new Blob<Dtype>(this->layer_param_.batchsize(), 1, 1, 1));
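+  // The prefetch buffers mirror the shapes of the top blobs; the background
+  // thread fills them while the network consumes the previous batch.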
  // datum size
  datum_channels_ = datum.channels();
  datum_height_ = datum.height();
  datum_width_ = datum.width();
  datum_size_ = datum.channels() * datum.height() * datum.width();
  CHECK_GT(datum_height_, cropsize);
  CHECK_GT(datum_width_, cropsize);
+  // Now, start the prefetch thread so the first batch is ready by the time
+  // Forward is called.
+  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+      static_cast<void*>(this)))
+      << "Failed to create prefetch thread.";
}
template <typename Dtype>
void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
-  Datum datum;
-  Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  Dtype* top_label = (*top)[1]->mutable_cpu_data();
-  const Dtype scale = this->layer_param_.scale();
-  const Dtype subtraction = this->layer_param_.subtraction();
-  int cropsize = this->layer_param_.cropsize();
-  for (int itemid = 0; itemid < (*top)[0]->num(); ++itemid) {
-    // get a blob
-    datum.ParseFromString(iter_->value().ToString());
-    const string& data = datum.data();
-    if (cropsize) {
-      CHECK(data.size()) << "Image cropping only support uint8 data";
-      int h_offset = rand() % (datum_height_ - cropsize);
-      int w_offset = rand() % (datum_width_ - cropsize);
-      for (int c = 0; c < datum_channels_; ++c) {
-        for (int h = 0; h < cropsize; ++h) {
-          for (int w = 0; w < cropsize; ++w) {
-            top_data[((itemid * datum_channels_ + c) * cropsize + h)
-                * cropsize + w] =
-                static_cast<Dtype>((uint8_t)data[
-                    (c * datum_height_ + h + h_offset) * datum_width_
-                    + w + w_offset]
-                ) * scale - subtraction;
-          }
-        }
-      }
-    } else {
-      // we will prefer to use data() first, and then try float_data()
-      if (data.size()) {
-        for (int j = 0; j < datum_size_; ++j) {
-          top_data[itemid * datum_size_ + j] =
-              (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
-        }
-      } else {
-        for (int j = 0; j < datum_size_; ++j) {
-          top_data[itemid * datum_size_ + j] =
-              (datum.float_data(j) * scale) - subtraction;
-        }
-      }
-    }
-    top_label[itemid] = datum.label();
-    // go to the next iter
-    iter_->Next();
-    if (!iter_->Valid()) {
-      // We have reached the end. Restart from the first.
-      LOG(INFO) << "Restarting data read from start.";
-      iter_->SeekToFirst();
-    }
-  }
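+  // The prefetch thread launched in SetUp (or at the end of the previous
+  // forward pass) has been filling prefetch_data_ and prefetch_label_ in
+  // parallel with the rest of the computation.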
+  // First, join the prefetch thread so the buffers are complete.
+  CHECK(!pthread_join(thread_, NULL)) << "Failed to join prefetch thread.";
+  // Copy the prefetched batch into the top blobs.
+  memcpy((*top)[0]->mutable_cpu_data(), prefetch_data_->cpu_data(),
+      sizeof(Dtype) * prefetch_data_->count());
+  memcpy((*top)[1]->mutable_cpu_data(), prefetch_label_->cpu_data(),
+      sizeof(Dtype) * prefetch_label_->count());
+  // Start a new prefetch thread for the next batch.
+  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+      static_cast<void*>(this)))
+      << "Failed to create prefetch thread.";
}
template <typename Dtype>
void DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
-  Forward_cpu(bottom, top);
-  // explicitly copy data to gpu - this is achieved by simply calling gpu_data
-  // functions.
-  // TODO(Yangqing): maybe we don't need this since data synchronization is
-  // simply done under the hood?
-  (*top)[0]->gpu_data();
-  (*top)[1]->gpu_data();
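+  // Same scheme as Forward_cpu, except that the prefetched batch is copied
+  // from host memory straight to the GPU.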
+  // First, join the prefetch thread so the buffers are complete.
+  CHECK(!pthread_join(thread_, NULL)) << "Failed to join prefetch thread.";
+  // Copy the prefetched batch to the device.
+  CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
+      prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
+      cudaMemcpyHostToDevice));
+  CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
+      prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
+      cudaMemcpyHostToDevice));
+  // Start a new prefetch thread for the next batch.
+  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+      static_cast<void*>(this)))
+      << "Failed to create prefetch thread.";
}
// The backward operations are dummy - they do not carry any computation.