do the same as prev commit for ImageDataLayer
authorJeff Donahue <jeff.donahue@gmail.com>
Tue, 22 Apr 2014 03:43:39 +0000 (20:43 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Tue, 22 Apr 2014 03:43:39 +0000 (20:43 -0700)
include/caffe/vision_layers.hpp
src/caffe/layers/image_data_layer.cpp
src/caffe/layers/image_data_layer.cu

index f4abcf4..6084c27 100644 (file)
@@ -477,6 +477,7 @@ class ImageDataLayer : public Layer<Dtype> {
   shared_ptr<Blob<Dtype> > prefetch_data_;
   shared_ptr<Blob<Dtype> > prefetch_label_;
   Blob<Dtype> data_mean_;
+  Caffe::Phase phase_;
 };
 
 template <typename Dtype>
index 8c23cc4..4182091 100644 (file)
@@ -60,7 +60,7 @@ void* ImageDataLayerPrefetch(void* layer_pointer) {
       CHECK(data.size()) << "Image cropping only support uint8 data";
       int h_off, w_off;
       // We only do random crop when we do training.
-      if (Caffe::phase() == Caffe::TRAIN) {
+      if (layer->phase_ == Caffe::TRAIN) {
         // NOLINT_NEXT_LINE(runtime/threadsafe_fn)
         h_off = rand() % (height - crop_size);
         // NOLINT_NEXT_LINE(runtime/threadsafe_fn)
@@ -228,6 +228,7 @@ void ImageDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   prefetch_label_->mutable_cpu_data();
   data_mean_.cpu_data();
   DLOG(INFO) << "Initializing prefetch";
+  phase_ = Caffe::phase();
   CHECK(!pthread_create(&thread_, NULL, ImageDataLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
   DLOG(INFO) << "Prefetch initialized.";
@@ -244,6 +245,7 @@ Dtype ImageDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   memcpy((*top)[1]->mutable_cpu_data(), prefetch_label_->cpu_data(),
       sizeof(Dtype) * prefetch_label_->count());
   // Start a new prefetch thread
+  phase_ = Caffe::phase();
   CHECK(!pthread_create(&thread_, NULL, ImageDataLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
   return Dtype(0.);
index 7b4952d..e95550f 100644 (file)
@@ -34,6 +34,7 @@ Dtype ImageDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
       cudaMemcpyHostToDevice));
   // Start a new prefetch thread
+  phase_ = Caffe::phase();
   CHECK(!pthread_create(&thread_, NULL, ImageDataLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
   return Dtype(0.);