data layer race condition bugfix
authorYangqing Jia <jiayq84@gmail.com>
Thu, 24 Oct 2013 17:47:46 +0000 (10:47 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Thu, 24 Oct 2013 17:47:46 +0000 (10:47 -0700)
examples/demo_mnist.cpp
src/caffe/layers/data_layer.cpp

index 9eb6204..11d3fc5 100644 (file)
@@ -25,6 +25,7 @@ int main(int argc, char** argv) {
     return 0;
   }
   google::InitGoogleLogging(argv[0]);
+  Caffe::DeviceQuery();
 
   if (argc == 4 && strcmp(argv[3], "GPU") == 0) {
     LOG(ERROR) << "Using GPU";
index fd26c5d..9ed9516 100644 (file)
@@ -170,8 +170,9 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   // cpu_data calls so that the prefetch thread does not accidentally make
   // simultaneous cudaMalloc calls when the main thread is running. In some
   // GPUs this seems to cause failures if we do not so.
-  layer->prefetch_data_->mutable_cpu_data();
-  layer->prefetch_label_->mutable_cpu_data();
+  prefetch_data_->mutable_cpu_data();
+  prefetch_label_->mutable_cpu_data();
+  data_mean_.cpu_data();
   // LOG(INFO) << "Initializing prefetch";
   CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";