prefetcher race condition
authorYangqing Jia <jiayq84@gmail.com>
Thu, 24 Oct 2013 17:33:50 +0000 (10:33 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Thu, 24 Oct 2013 17:33:50 +0000 (10:33 -0700)
src/caffe/layers/data_layer.cpp

index 0d4b0b5..fd26c5d 100644 (file)
@@ -166,7 +166,12 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
     // Simply initialize an all-empty mean.
     data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);
   }
-  // Now, start the prefetch thread.
+  // Now, start the prefetch thread. Before calling prefetch, we make two
+  // cpu_data calls so that the prefetch thread does not accidentally make
+  // simultaneous cudaMalloc calls when the main thread is running. In some
+  // GPUs this seems to cause failures if we do not so.
+  layer->prefetch_data_->mutable_cpu_data();
+  layer->prefetch_label_->mutable_cpu_data();
   // LOG(INFO) << "Initializing prefetch";
   CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";