post rebase fixes: images layer and padding layer compute loss in
authorJeff Donahue <jeff.donahue@gmail.com>
Wed, 19 Mar 2014 01:55:56 +0000 (18:55 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Wed, 19 Mar 2014 19:37:31 +0000 (12:37 -0700)
forward

include/caffe/vision_layers.hpp
src/caffe/layers/images_layer.cpp
src/caffe/layers/padding_layer.cpp
src/caffe/layers/padding_layer.cu

index 3a3bdfd..9c0850e 100644 (file)
@@ -206,13 +206,13 @@ class PaddingLayer : public Layer<Dtype> {
       vector<Blob<Dtype>*>* top);
 
  protected:
-  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
   unsigned int PAD_;
   int NUM_;
@@ -425,14 +425,14 @@ class ImagesLayer : public Layer<Dtype> {
       vector<Blob<Dtype>*>* top);
 
  protected:
-  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
 
   vector<std::pair<std::string, int> > lines_;
   int lines_id_;
index e750e01..6208a9e 100644 (file)
@@ -233,7 +233,7 @@ void ImagesLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
-void ImagesLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+Dtype ImagesLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   // First, join the thread
   CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
@@ -245,10 +245,11 @@ void ImagesLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   // Start a new prefetch thread
   CHECK(!pthread_create(&thread_, NULL, ImagesLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
+  return Dtype(0.);
 }
 
 template <typename Dtype>
-void ImagesLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+Dtype ImagesLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   // First, join the thread
   CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
@@ -262,18 +263,6 @@ void ImagesLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   // Start a new prefetch thread
   CHECK(!pthread_create(&thread_, NULL, ImagesLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
-}
-
-// The backward operations are dummy - they do not carry any computation.
-template <typename Dtype>
-Dtype ImagesLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  return Dtype(0.);
-}
-
-template <typename Dtype>
-Dtype ImagesLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
   return Dtype(0.);
 }
 
index 4cb67df..658cc6a 100644 (file)
@@ -29,7 +29,7 @@ void PaddingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
-void PaddingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+Dtype PaddingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
   const Dtype* bottom_data = bottom[0]->cpu_data();
@@ -47,10 +47,11 @@ void PaddingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       }
     }
   }
+  return Dtype(0.);
 }
 
 template <typename Dtype>
-Dtype PaddingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+void PaddingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
   Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
@@ -66,7 +67,6 @@ Dtype PaddingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       }
     }
   }
-  return Dtype(0.);
 }
 
 INSTANTIATE_CLASS(PaddingLayer);
index 7ec28a9..d476df5 100644 (file)
@@ -27,7 +27,7 @@ __global__ void PaddingForward(const int count, const Dtype* in, Dtype* out,
 }
 
 template <typename Dtype>
-void PaddingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+Dtype PaddingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -39,6 +39,7 @@ void PaddingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       count, bottom_data, top_data, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
       PAD_);
   CUDA_POST_KERNEL_CHECK;
+  return Dtype(0);
 }
 
 template <typename Dtype>
@@ -61,7 +62,7 @@ __global__ void PaddingBackward(const int count, const Dtype* in, Dtype* out,
 }
 
 template <typename Dtype>
-Dtype PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+void PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down) {
@@ -74,7 +75,6 @@ Dtype PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         PAD_);
     CUDA_POST_KERNEL_CHECK;
   }
-  return Dtype(0);
 }
 
 INSTANTIATE_CLASS(PaddingLayer);