enable DataLayer to output unlabeled data
author     Jeff Donahue <jeff.donahue@gmail.com>  Tue, 15 Apr 2014 22:12:57 +0000 (15:12 -0700)
committer  Jeff Donahue <jeff.donahue@gmail.com>  Tue, 15 Apr 2014 22:12:57 +0000 (15:12 -0700)
examples/mnist/mnist_autoencoder_test.prototxt
examples/mnist/mnist_autoencoder_train.prototxt
include/caffe/vision_layers.hpp
src/caffe/layers/data_layer.cpp
src/caffe/layers/data_layer.cu

diff --git a/examples/mnist/mnist_autoencoder_test.prototxt b/examples/mnist/mnist_autoencoder_test.prototxt
index bec7a3c..440ccd3 100644
@@ -1,7 +1,6 @@
 name: "MNISTAutoencoder"
 layers {
   top: "data"
-  top: "label"
   name: "data"
   type: DATA
   data_param {
diff --git a/examples/mnist/mnist_autoencoder_train.prototxt b/examples/mnist/mnist_autoencoder_train.prototxt
index d5201eb..90d2cff 100644
@@ -1,7 +1,6 @@
 name: "MNISTAutoencoder"
 layers {
   top: "data"
-  top: "label"
   name: "data"
   type: DATA
   data_param {
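
With the "label" top removed, the data layer in both the test and train autoencoder prototxts declares a single top blob. Reconstructed from the two hunks above (data_param contents elided, as in the diff), the patched layer reads:

    layers {
      top: "data"
      name: "data"
      type: DATA
      data_param {
        ...
      }
    }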
diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
index 7cd5159..9052604 100644
@@ -296,6 +296,7 @@ class DataLayer : public Layer<Dtype> {
   shared_ptr<Blob<Dtype> > prefetch_data_;
   shared_ptr<Blob<Dtype> > prefetch_label_;
   Blob<Dtype> data_mean_;
+  bool output_labels_;
 };
 
 template <typename Dtype>
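
The header change is the one new piece of state. A hypothetical doc comment summarizing how the rest of the patch uses it:

    // output_labels_ is set in SetUp(): true when the layer is given two
    // top blobs (data + label), false when it is given one (data only).
    // When false, prefetch_label_ is never reset, so no label memory is
    // allocated and every label code path is skipped.
    bool output_labels_;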
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 399f771..8340259 100644
@@ -9,6 +9,7 @@
 
 #include "caffe/layer.hpp"
 #include "caffe/util/io.hpp"
+#include "caffe/util/math_functions.hpp"
 #include "caffe/vision_layers.hpp"
 
 using std::string;
@@ -23,7 +24,10 @@ void* DataLayerPrefetch(void* layer_pointer) {
   Datum datum;
   CHECK(layer->prefetch_data_);
   Dtype* top_data = layer->prefetch_data_->mutable_cpu_data();
-  Dtype* top_label = layer->prefetch_label_->mutable_cpu_data();
+  Dtype* top_label;
+  if (layer->output_labels_) {
+    top_label = layer->prefetch_label_->mutable_cpu_data();
+  }
   const Dtype scale = layer->layer_param_.data_param().scale();
   const int batch_size = layer->layer_param_.data_param().batch_size();
   const int crop_size = layer->layer_param_.data_param().crop_size();
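
Note that top_label is left uninitialized when output_labels_ is false; every later use is guarded by the same flag, so this is safe. A slightly more defensive variant (a sketch, not what the patch does) would null-initialize the pointer:

    Dtype* top_label = NULL;
    if (layer->output_labels_) {
      top_label = layer->prefetch_label_->mutable_cpu_data();
    }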
@@ -105,7 +109,9 @@ void* DataLayerPrefetch(void* layer_pointer) {
       }
     }
 
-    top_label[item_id] = datum.label();
+    if (layer->output_labels_) {
+      top_label[item_id] = datum.label();
+    }
     // go to the next iter
     layer->iter_->Next();
     if (!layer->iter_->Valid()) {
@@ -128,7 +134,13 @@ template <typename Dtype>
 void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 0) << "Data Layer takes no input blobs.";
-  CHECK_EQ(top->size(), 2) << "Data Layer takes two blobs as output.";
+  CHECK_GE(top->size(), 1) << "Data Layer takes at least one blob as output.";
+  CHECK_LE(top->size(), 2) << "Data Layer takes at most two blobs as output.";
+  if (top->size() == 1) {
+    output_labels_ = false;
+  } else {
+    output_labels_ = true;
+  }
   // Initialize the leveldb
   leveldb::DB* db_temp;
   leveldb::Options options;
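
The two glog checks bound the top count to the range [1, 2], and the if/else reduces to a single boolean assignment; an equivalent, terser sketch of the same setup logic:

    CHECK_GE(top->size(), 1) << "Data Layer takes at least one blob as output.";
    CHECK_LE(top->size(), 2) << "Data Layer takes at most two blobs as output.";
    output_labels_ = (top->size() == 2);  // the second top blob is the label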
@@ -178,9 +190,11 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       << (*top)[0]->channels() << "," << (*top)[0]->height() << ","
       << (*top)[0]->width();
   // label
-  (*top)[1]->Reshape(this->layer_param_.data_param().batch_size(), 1, 1, 1);
-  prefetch_label_.reset(
-      new Blob<Dtype>(this->layer_param_.data_param().batch_size(), 1, 1, 1));
+  if (output_labels_) {
+    (*top)[1]->Reshape(this->layer_param_.data_param().batch_size(), 1, 1, 1);
+    prefetch_label_.reset(
+        new Blob<Dtype>(this->layer_param_.data_param().batch_size(), 1, 1, 1));
+  }
   // datum size
   datum_channels_ = datum.channels();
   datum_height_ = datum.height();
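
Each label is one scalar per example, so the optional second top is shaped (batch_size, 1, 1, 1). A hypothetical downstream read of label i, valid only when the layer was configured with two tops:

    const Dtype label = (*top)[1]->cpu_data()[i];  // i in [0, batch_size)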
@@ -208,7 +222,9 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   // simultaneous cudaMalloc calls when the main thread is running. In some
  // GPUs this seems to cause failures if we do not do so.
   prefetch_data_->mutable_cpu_data();
-  prefetch_label_->mutable_cpu_data();
+  if (output_labels_) {
+    prefetch_label_->mutable_cpu_data();
+  }
   data_mean_.cpu_data();
   DLOG(INFO) << "Initializing prefetch";
   CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
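
The warm-up calls work because Caffe blobs allocate lazily on first access; touching the buffers here forces allocation onto the main thread before the prefetch thread starts. A hypothetical simplification of that lazy-allocation behavior:

    #include <cstdlib>
    // First access allocates; later accesses reuse the same buffer.
    void* lazy_cpu_data(void*& ptr, size_t size) {
      if (ptr == NULL) ptr = std::malloc(size);
      return ptr;
    }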
@@ -222,10 +238,12 @@ Dtype DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   // First, join the thread
   CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
   // Copy the data
-  memcpy((*top)[0]->mutable_cpu_data(), prefetch_data_->cpu_data(),
-      sizeof(Dtype) * prefetch_data_->count());
-  memcpy((*top)[1]->mutable_cpu_data(), prefetch_label_->cpu_data(),
-      sizeof(Dtype) * prefetch_label_->count());
+  caffe_copy(prefetch_data_->count(), prefetch_data_->cpu_data(),
+             (*top)[0]->mutable_cpu_data());
+  if (output_labels_) {
+    caffe_copy(prefetch_label_->count(), prefetch_label_->cpu_data(),
+               (*top)[1]->mutable_cpu_data());
+  }
   // Start a new prefetch thread
   CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
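
caffe_copy(n, src, dst), declared in the newly included math_functions.hpp, replaces the raw memcpy calls. Its host-side semantics are equivalent to the sketch below; the real implementation may dispatch to BLAS copy routines:

    #include <cstring>
    // Sketch of caffe_copy's host-side semantics: copy n elements of
    // Dtype from src to dst.
    template <typename Dtype>
    void caffe_copy_sketch(const int n, const Dtype* src, Dtype* dst) {
      std::memcpy(dst, src, sizeof(Dtype) * n);
    }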
diff --git a/src/caffe/layers/data_layer.cu b/src/caffe/layers/data_layer.cu
index 86f4757..15ef016 100644
@@ -24,9 +24,11 @@ Dtype DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
       prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
       cudaMemcpyHostToDevice));
-  CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
-      prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
-      cudaMemcpyHostToDevice));
+  if (output_labels_) {
+    CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
+        prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
+        cudaMemcpyHostToDevice));
+  }
   // Start a new prefetch thread
   CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
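
End to end, a net built from the single-top prototxt simply never creates a "label" blob, on CPU or GPU. A hypothetical usage sketch (Net API calls assumed from same-era Caffe):

    #include "caffe/net.hpp"

    int main() {
      caffe::Net<float> net("examples/mnist/mnist_autoencoder_train.prototxt");
      // The DATA layer contributed only a "data" blob; no "label" blob
      // exists anywhere in this net.
      CHECK(!net.has_blob("label"));
      net.ForwardPrefilled();  // copies data only, per the diffs above
      return 0;
    }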