layer definition reorganization and documentation
authorSergey Karayev <sergeykarayev@gmail.com>
Tue, 29 Apr 2014 02:06:07 +0000 (19:06 -0700)
committerSergey Karayev <sergeykarayev@gmail.com>
Fri, 16 May 2014 17:58:38 +0000 (10:58 -0700)
- split out neuron, loss, and data layers into own header files
- added LossLayer class with common SetUp checks
- in-progress concise documentation of each layer's purpose

include/caffe/data_layers.hpp [new file with mode: 0644]
include/caffe/loss_layers.hpp [new file with mode: 0644]
include/caffe/neuron_layers.hpp [new file with mode: 0644]
include/caffe/vision_layers.hpp
src/caffe/layers/loss_layer.cpp
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp

diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp
new file mode 100644 (file)
index 0000000..a1742fe
--- /dev/null
@@ -0,0 +1,215 @@
+// Copyright 2014 BVLC and contributors.
+
+#ifndef CAFFE_DATA_LAYERS_HPP_
+#define CAFFE_DATA_LAYERS_HPP_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "leveldb/db.h"
+#include "pthread.h"
+#include "hdf5.h"
+#include "boost/scoped_ptr.hpp"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+#define HDF5_DATA_DATASET_NAME "data"
+#define HDF5_DATA_LABEL_NAME "label"
+
+// This function is used to create a pthread that prefetches the data.
+template <typename Dtype>
+void* DataLayerPrefetch(void* layer_pointer);
+
+template <typename Dtype>
+class DataLayer : public Layer<Dtype> {
+  // The function used to perform prefetching.
+  friend void* DataLayerPrefetch<Dtype>(void* layer_pointer);
+
+ public:
+  explicit DataLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual ~DataLayer();
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+
+  virtual void CreatePrefetchThread();
+  virtual void JoinPrefetchThread();
+  virtual unsigned int PrefetchRand();
+
+  shared_ptr<Caffe::RNG> prefetch_rng_;
+  shared_ptr<leveldb::DB> db_;
+  shared_ptr<leveldb::Iterator> iter_;
+  int datum_channels_;
+  int datum_height_;
+  int datum_width_;
+  int datum_size_;
+  pthread_t thread_;
+  shared_ptr<Blob<Dtype> > prefetch_data_;
+  shared_ptr<Blob<Dtype> > prefetch_label_;
+  Blob<Dtype> data_mean_;
+  bool output_labels_;
+  Caffe::Phase phase_;
+};
+
+template <typename Dtype>
+class HDF5OutputLayer : public Layer<Dtype> {
+ public:
+  explicit HDF5OutputLayer(const LayerParameter& param);
+  virtual ~HDF5OutputLayer();
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  inline std::string file_name() const { return file_name_; }
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void SaveBlobs();
+
+  std::string file_name_;
+  hid_t file_id_;
+  Blob<Dtype> data_blob_;
+  Blob<Dtype> label_blob_;
+};
+
+
+template <typename Dtype>
+class HDF5DataLayer : public Layer<Dtype> {
+ public:
+  explicit HDF5DataLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual ~HDF5DataLayer();
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void LoadHDF5FileData(const char* filename);
+
+  std::vector<std::string> hdf_filenames_;
+  unsigned int num_files_;
+  unsigned int current_file_;
+  hsize_t current_row_;
+  Blob<Dtype> data_blob_;
+  Blob<Dtype> label_blob_;
+};
+
+// This function is used to create a pthread that prefetches the data.
+template <typename Dtype>
+void* ImageDataLayerPrefetch(void* layer_pointer);
+
+template <typename Dtype>
+class ImageDataLayer : public Layer<Dtype> {
+  // The function used to perform prefetching.
+  friend void* ImageDataLayerPrefetch<Dtype>(void* layer_pointer);
+
+ public:
+  explicit ImageDataLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual ~ImageDataLayer();
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+
+  virtual void ShuffleImages();
+
+  virtual void CreatePrefetchThread();
+  virtual void JoinPrefetchThread();
+  virtual unsigned int PrefetchRand();
+
+  shared_ptr<Caffe::RNG> prefetch_rng_;
+  vector<std::pair<std::string, int> > lines_;
+  int lines_id_;
+  int datum_channels_;
+  int datum_height_;
+  int datum_width_;
+  int datum_size_;
+  pthread_t thread_;
+  shared_ptr<Blob<Dtype> > prefetch_data_;
+  shared_ptr<Blob<Dtype> > prefetch_label_;
+  Blob<Dtype> data_mean_;
+  Caffe::Phase phase_;
+};
+
+
+// This function is used to create a pthread that prefetches the window data.
+template <typename Dtype>
+void* WindowDataLayerPrefetch(void* layer_pointer);
+
+template <typename Dtype>
+class WindowDataLayer : public Layer<Dtype> {
+  // The function used to perform prefetching.
+  friend void* WindowDataLayerPrefetch<Dtype>(void* layer_pointer);
+
+ public:
+  explicit WindowDataLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual ~WindowDataLayer();
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+
+  virtual void CreatePrefetchThread();
+  virtual void JoinPrefetchThread();
+  virtual unsigned int PrefetchRand();
+
+  shared_ptr<Caffe::RNG> prefetch_rng_;
+  pthread_t thread_;
+  shared_ptr<Blob<Dtype> > prefetch_data_;
+  shared_ptr<Blob<Dtype> > prefetch_label_;
+  Blob<Dtype> data_mean_;
+  vector<std::pair<std::string, vector<int> > > image_database_;
+  enum WindowField { IMAGE_INDEX, LABEL, OVERLAP, X1, Y1, X2, Y2, NUM };
+  vector<vector<float> > fg_windows_;
+  vector<vector<float> > bg_windows_;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_DATA_LAYERS_HPP_
diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp
new file mode 100644 (file)
index 0000000..6ddfcc4
--- /dev/null
@@ -0,0 +1,150 @@
+// Copyright 2014 BVLC and contributors.
+
+#ifndef CAFFE_LOSS_LAYERS_HPP_
+#define CAFFE_LOSS_LAYERS_HPP_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "leveldb/db.h"
+#include "pthread.h"
+#include "boost/scoped_ptr.hpp"
+#include "hdf5.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/neuron_layers.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+// LossLayer takes two inputs of same num, and has no output.
+template <typename Dtype>
+class LossLayer : public Layer<Dtype> {
+ public:
+  explicit LossLayer(const LayerParameter& param)
+     : Layer<Dtype>(param) {}
+  virtual void SetUp(
+      const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top);
+  virtual void FurtherSetUp(
+      const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top);
+};
+
+// SigmoidCrossEntropyLossLayer
+template <typename Dtype>
+class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
+ public:
+  explicit SigmoidCrossEntropyLossLayer(const LayerParameter& param)
+      : LossLayer<Dtype>(param),
+          sigmoid_layer_(new SigmoidLayer<Dtype>(param)),
+          sigmoid_output_(new Blob<Dtype>()) {}
+  virtual void FurtherSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
+  // sigmoid_output stores the output of the sigmoid layer.
+  shared_ptr<Blob<Dtype> > sigmoid_output_;
+  // Vector holders to call the underlying sigmoid layer forward and backward.
+  vector<Blob<Dtype>*> sigmoid_bottom_vec_;
+  vector<Blob<Dtype>*> sigmoid_top_vec_;
+};
+
+// EuclideanLossLayer: compute y = 1/2 \sum_i (x_i - x'_i)^2
+template <typename Dtype>
+class EuclideanLossLayer : public LossLayer<Dtype> {
+ public:
+  explicit EuclideanLossLayer(const LayerParameter& param)
+      : LossLayer<Dtype>(param), difference_() {}
+  virtual void FurtherSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  Blob<Dtype> difference_;
+};
+
+
+template <typename Dtype>
+class InfogainLossLayer : public LossLayer<Dtype> {
+ public:
+  explicit InfogainLossLayer(const LayerParameter& param)
+      : LossLayer<Dtype>(param), infogain_() {}
+  virtual void FurtherSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  Blob<Dtype> infogain_;
+};
+
+template <typename Dtype>
+class HingeLossLayer : public LossLayer<Dtype> {
+ public:
+  explicit HingeLossLayer(const LayerParameter& param)
+      : LossLayer<Dtype>(param) {}
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+template <typename Dtype>
+class MultinomialLogisticLossLayer : public LossLayer<Dtype> {
+ public:
+  explicit MultinomialLogisticLossLayer(const LayerParameter& param)
+      : LossLayer<Dtype>(param) {}
+  virtual void FurtherSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+// AccuracyLayer: not an actual loss layer;
+// computes the accuracy and logprob of x with respect to y'.
+template <typename Dtype>
+class AccuracyLayer : public Layer<Dtype> {
+ public:
+  explicit AccuracyLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  // The accuracy layer should not be used to compute backward operations.
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+    NOT_IMPLEMENTED;
+  }
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_LOSS_LAYERS_HPP_
diff --git a/include/caffe/neuron_layers.hpp b/include/caffe/neuron_layers.hpp
new file mode 100644 (file)
index 0000000..68ebc2b
--- /dev/null
@@ -0,0 +1,160 @@
+// Copyright 2014 BVLC and contributors.
+
+#ifndef CAFFE_NEURON_LAYERS_HPP_
+#define CAFFE_NEURON_LAYERS_HPP_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "leveldb/db.h"
+#include "pthread.h"
+#include "boost/scoped_ptr.hpp"
+#include "hdf5.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#define HDF5_DATA_DATASET_NAME "data"
+#define HDF5_DATA_LABEL_NAME "label"
+
+namespace caffe {
+
+// NeuronLayer is an interface for layers that work on single elements.
+template <typename Dtype>
+class NeuronLayer : public Layer<Dtype> {
+ public:
+  explicit NeuronLayer(const LayerParameter& param)
+     : Layer<Dtype>(param) {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+};
+
+// BNLLLayer
+template <typename Dtype>
+class BNLLLayer : public NeuronLayer<Dtype> {
+ public:
+  explicit BNLLLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+// DropoutLayer sets random portion of its input to 0.
+template <typename Dtype>
+class DropoutLayer : public NeuronLayer<Dtype> {
+ public:
+  explicit DropoutLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  shared_ptr<SyncedMemory> rand_vec_;
+  Dtype threshold_;
+  Dtype scale_;
+  unsigned int uint_thres_;
+};
+
+// PowerLayer computes y = (shift + scale * x)^power
+template <typename Dtype>
+class PowerLayer : public NeuronLayer<Dtype> {
+ public:
+  explicit PowerLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  Dtype power_;
+  Dtype scale_;
+  Dtype shift_;
+  Dtype diff_scale_;
+};
+
+// ReLULayer computes y = max(0, x).
+template <typename Dtype>
+class ReLULayer : public NeuronLayer<Dtype> {
+ public:
+  explicit ReLULayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+// SigmoidLayer computes y = 1. / (1 + exp(-x))
+template <typename Dtype>
+class SigmoidLayer : public NeuronLayer<Dtype> {
+ public:
+  explicit SigmoidLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+// TanHLayer: computes y = 1. * (exp(2 * x) - 1) / (exp(2 * x) + 1)
+template <typename Dtype>
+class TanHLayer : public NeuronLayer<Dtype> {
+ public:
+  explicit TanHLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_NEURON_LAYERS_HPP_
index 4765398..7334ada 100644 (file)
 #include <utility>
 #include <vector>
 
-#include "leveldb/db.h"
-#include "pthread.h"
-#include "boost/scoped_ptr.hpp"
-#include "hdf5.h"
-
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/layer.hpp"
+#include "caffe/neuron_layers.hpp"
+#include "caffe/loss_layers.hpp"
+#include "caffe/data_layers.hpp"
 #include "caffe/proto/caffe.pb.h"
 
-#define HDF5_DATA_DATASET_NAME "data"
-#define HDF5_DATA_LABEL_NAME "label"
-
 namespace caffe {
 
-
-// The neuron layer is a specific type of layers that just works on single
-// celements.
-template <typename Dtype>
-class NeuronLayer : public Layer<Dtype> {
- public:
-  explicit NeuronLayer(const LayerParameter& param)
-     : Layer<Dtype>(param) {}
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-};
-
-template <typename Dtype>
-class BNLLLayer : public NeuronLayer<Dtype> {
- public:
-  explicit BNLLLayer(const LayerParameter& param)
-      : NeuronLayer<Dtype>(param) {}
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-};
-
-template <typename Dtype>
-class DropoutLayer : public NeuronLayer<Dtype> {
- public:
-  explicit DropoutLayer(const LayerParameter& param)
-      : NeuronLayer<Dtype>(param) {}
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-
-  shared_ptr<SyncedMemory> rand_vec_;
-  Dtype threshold_;
-  Dtype scale_;
-  unsigned int uint_thres_;
-};
-
-template <typename Dtype>
-class PowerLayer : public NeuronLayer<Dtype> {
- public:
-  explicit PowerLayer(const LayerParameter& param)
-      : NeuronLayer<Dtype>(param) {}
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-
-  Dtype power_;
-  Dtype scale_;
-  Dtype shift_;
-  Dtype diff_scale_;
-};
-
-template <typename Dtype>
-class ReLULayer : public NeuronLayer<Dtype> {
- public:
-  explicit ReLULayer(const LayerParameter& param)
-      : NeuronLayer<Dtype>(param) {}
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-};
-
-template <typename Dtype>
-class SigmoidLayer : public NeuronLayer<Dtype> {
- public:
-  explicit SigmoidLayer(const LayerParameter& param)
-      : NeuronLayer<Dtype>(param) {}
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-};
-
-template <typename Dtype>
-class SigmoidCrossEntropyLossLayer : public Layer<Dtype> {
- public:
-  explicit SigmoidCrossEntropyLossLayer(const LayerParameter& param)
-      : Layer<Dtype>(param),
-          sigmoid_layer_(new SigmoidLayer<Dtype>(param)),
-          sigmoid_output_(new Blob<Dtype>()) {}
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-
-  shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
-  // sigmoid_output stores the output of the sigmoid layer.
-  shared_ptr<Blob<Dtype> > sigmoid_output_;
-  // Vector holders to call the underlying sigmoid layer forward and backward.
-  vector<Blob<Dtype>*> sigmoid_bottom_vec_;
-  vector<Blob<Dtype>*> sigmoid_top_vec_;
-};
-
-template <typename Dtype>
-class TanHLayer : public NeuronLayer<Dtype> {
- public:
-  explicit TanHLayer(const LayerParameter& param)
-      : NeuronLayer<Dtype>(param) {}
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-};
-
-
-template <typename Dtype>
-class AccuracyLayer : public Layer<Dtype> {
- public:
-  explicit AccuracyLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  // The accuracy layer should not be used to compute backward operations.
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-    NOT_IMPLEMENTED;
-  }
-};
-
 template <typename Dtype>
 class ConcatLayer : public Layer<Dtype> {
  public:
@@ -260,50 +79,6 @@ class ConvolutionLayer : public Layer<Dtype> {
   int N_;
 };
 
-// This function is used to create a pthread that prefetches the data.
-template <typename Dtype>
-void* DataLayerPrefetch(void* layer_pointer);
-
-template <typename Dtype>
-class DataLayer : public Layer<Dtype> {
-  // The function used to perform prefetching.
-  friend void* DataLayerPrefetch<Dtype>(void* layer_pointer);
-
- public:
-  explicit DataLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
-  virtual ~DataLayer();
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
-
-  virtual void CreatePrefetchThread();
-  virtual void JoinPrefetchThread();
-  virtual unsigned int PrefetchRand();
-
-  shared_ptr<Caffe::RNG> prefetch_rng_;
-  shared_ptr<leveldb::DB> db_;
-  shared_ptr<leveldb::Iterator> iter_;
-  int datum_channels_;
-  int datum_height_;
-  int datum_width_;
-  int datum_size_;
-  pthread_t thread_;
-  shared_ptr<Blob<Dtype> > prefetch_data_;
-  shared_ptr<Blob<Dtype> > prefetch_label_;
-  Blob<Dtype> data_mean_;
-  bool output_labels_;
-  Caffe::Phase phase_;
-};
 
 template <typename Dtype>
 class EltwiseProductLayer : public Layer<Dtype> {
@@ -325,27 +100,6 @@ class EltwiseProductLayer : public Layer<Dtype> {
 };
 
 template <typename Dtype>
-class EuclideanLossLayer : public Layer<Dtype> {
- public:
-  explicit EuclideanLossLayer(const LayerParameter& param)
-      : Layer<Dtype>(param), difference_() {}
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  // virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-  //     vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  // virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-  //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-
-  Blob<Dtype> difference_;
-};
-
-template <typename Dtype>
 class FlattenLayer : public Layer<Dtype> {
  public:
   explicit FlattenLayer(const LayerParameter& param)
@@ -367,79 +121,6 @@ class FlattenLayer : public Layer<Dtype> {
 };
 
 template <typename Dtype>
-class HDF5OutputLayer : public Layer<Dtype> {
- public:
-  explicit HDF5OutputLayer(const LayerParameter& param);
-  virtual ~HDF5OutputLayer();
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  inline std::string file_name() const { return file_name_; }
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void SaveBlobs();
-
-  std::string file_name_;
-  hid_t file_id_;
-  Blob<Dtype> data_blob_;
-  Blob<Dtype> label_blob_;
-};
-
-template <typename Dtype>
-class HDF5DataLayer : public Layer<Dtype> {
- public:
-  explicit HDF5DataLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
-  virtual ~HDF5DataLayer();
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  virtual void LoadHDF5FileData(const char* filename);
-
-  std::vector<std::string> hdf_filenames_;
-  unsigned int num_files_;
-  unsigned int current_file_;
-  hsize_t current_row_;
-  Blob<Dtype> data_blob_;
-  Blob<Dtype> label_blob_;
-};
-
-template <typename Dtype>
-class HingeLossLayer : public Layer<Dtype> {
- public:
-  explicit HingeLossLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  // virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-  //     vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  // virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-  //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-};
-
-template <typename Dtype>
 class Im2colLayer : public Layer<Dtype> {
  public:
   explicit Im2colLayer(const LayerParameter& param)
@@ -465,73 +146,6 @@ class Im2colLayer : public Layer<Dtype> {
   int pad_;
 };
 
-// This function is used to create a pthread that prefetches the data.
-template <typename Dtype>
-void* ImageDataLayerPrefetch(void* layer_pointer);
-
-template <typename Dtype>
-class ImageDataLayer : public Layer<Dtype> {
-  // The function used to perform prefetching.
-  friend void* ImageDataLayerPrefetch<Dtype>(void* layer_pointer);
-
- public:
-  explicit ImageDataLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
-  virtual ~ImageDataLayer();
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
-
-  virtual void ShuffleImages();
-
-  virtual void CreatePrefetchThread();
-  virtual void JoinPrefetchThread();
-  virtual unsigned int PrefetchRand();
-
-  shared_ptr<Caffe::RNG> prefetch_rng_;
-  vector<std::pair<std::string, int> > lines_;
-  int lines_id_;
-  int datum_channels_;
-  int datum_height_;
-  int datum_width_;
-  int datum_size_;
-  pthread_t thread_;
-  shared_ptr<Blob<Dtype> > prefetch_data_;
-  shared_ptr<Blob<Dtype> > prefetch_label_;
-  Blob<Dtype> data_mean_;
-  Caffe::Phase phase_;
-};
-
-template <typename Dtype>
-class InfogainLossLayer : public Layer<Dtype> {
- public:
-  explicit InfogainLossLayer(const LayerParameter& param)
-      : Layer<Dtype>(param), infogain_() {}
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  // virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-  //     vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  // virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-  //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-
-  Blob<Dtype> infogain_;
-};
-
 template <typename Dtype>
 class InnerProductLayer : public Layer<Dtype> {
  public:
@@ -659,25 +273,6 @@ class MemoryDataLayer : public Layer<Dtype> {
 };
 
 template <typename Dtype>
-class MultinomialLogisticLossLayer : public Layer<Dtype> {
- public:
-  explicit MultinomialLogisticLossLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  // virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-  //     vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  // virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-  //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-};
-
-template <typename Dtype>
 class PoolingLayer : public Layer<Dtype> {
  public:
   explicit PoolingLayer(const LayerParameter& param)
@@ -782,48 +377,6 @@ class SplitLayer : public Layer<Dtype> {
   int count_;
 };
 
-// This function is used to create a pthread that prefetches the window data.
-template <typename Dtype>
-void* WindowDataLayerPrefetch(void* layer_pointer);
-
-template <typename Dtype>
-class WindowDataLayer : public Layer<Dtype> {
-  // The function used to perform prefetching.
-  friend void* WindowDataLayerPrefetch<Dtype>(void* layer_pointer);
-
- public:
-  explicit WindowDataLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
-  virtual ~WindowDataLayer();
-  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-
- protected:
-  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
-
-  virtual void CreatePrefetchThread();
-  virtual void JoinPrefetchThread();
-  virtual unsigned int PrefetchRand();
-
-  shared_ptr<Caffe::RNG> prefetch_rng_;
-  pthread_t thread_;
-  shared_ptr<Blob<Dtype> > prefetch_data_;
-  shared_ptr<Blob<Dtype> > prefetch_label_;
-  Blob<Dtype> data_mean_;
-  vector<std::pair<std::string, vector<int> > > image_database_;
-  enum WindowField { IMAGE_INDEX, LABEL, OVERLAP, X1, Y1, X2, Y2, NUM };
-  vector<vector<float> > fg_windows_;
-  vector<vector<float> > bg_windows_;
-};
-
-
 }  // namespace caffe
 
 #endif  // CAFFE_VISION_LAYERS_HPP_
index b7f812a..3fc34a6 100644 (file)
@@ -17,12 +17,24 @@ namespace caffe {
 const float kLOG_THRESHOLD = 1e-20;
 
 template <typename Dtype>
-void MultinomialLogisticLossLayer<Dtype>::SetUp(
+void LossLayer<Dtype>::SetUp(
     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
   CHECK_EQ(top->size(), 0) << "Loss Layer takes no output.";
   CHECK_EQ(bottom[0]->num(), bottom[1]->num())
       << "The data and label should have the same number.";
+  FurtherSetUp(bottom, top);
+}
+
+template <typename Dtype>
+void LossLayer<Dtype>::FurtherSetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  // Nothing to do
+}
+
+template <typename Dtype>
+void MultinomialLogisticLossLayer<Dtype>::FurtherSetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom[1]->channels(), 1);
   CHECK_EQ(bottom[1]->height(), 1);
   CHECK_EQ(bottom[1]->width(), 1);
@@ -63,18 +75,15 @@ void MultinomialLogisticLossLayer<Dtype>::Backward_cpu(
 
 
 template <typename Dtype>
-void InfogainLossLayer<Dtype>::SetUp(
+void InfogainLossLayer<Dtype>::FurtherSetUp(
     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
-  CHECK_EQ(top->size(), 0) << "Loss Layer takes no output.";
-  CHECK_EQ(bottom[0]->num(), bottom[1]->num())
-      << "The data and label should have the same number.";
   CHECK_EQ(bottom[1]->channels(), 1);
   CHECK_EQ(bottom[1]->height(), 1);
   CHECK_EQ(bottom[1]->width(), 1);
+
   BlobProto blob_proto;
-  ReadProtoFromBinaryFile(this->layer_param_.infogain_loss_param().source(),
-                          &blob_proto);
+  ReadProtoFromBinaryFile(
+    this->layer_param_.infogain_loss_param().source(), &blob_proto);
   infogain_.FromProto(blob_proto);
   CHECK_EQ(infogain_.num(), 1);
   CHECK_EQ(infogain_.channels(), 1);
@@ -124,12 +133,8 @@ void InfogainLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
 
 
 template <typename Dtype>
-void EuclideanLossLayer<Dtype>::SetUp(
+void EuclideanLossLayer<Dtype>::FurtherSetUp(
   const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
-  CHECK_EQ(top->size(), 0) << "Loss Layer takes no as output.";
-  CHECK_EQ(bottom[0]->num(), bottom[1]->num())
-      << "The data and label should have the same number.";
   CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
   CHECK_EQ(bottom[0]->height(), bottom[1]->height());
   CHECK_EQ(bottom[0]->width(), bottom[1]->width());
@@ -160,6 +165,44 @@ void EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
 }
 
 template <typename Dtype>
+Dtype HingeLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+  const Dtype* label = bottom[1]->cpu_data();
+  int num = bottom[0]->num();
+  int count = bottom[0]->count();
+  int dim = count / num;
+
+  caffe_copy(count, bottom_data, bottom_diff);
+  for (int i = 0; i < num; ++i) {
+    bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
+  }
+  for (int i = 0; i < num; ++i) {
+    for (int j = 0; j < dim; ++j) {
+      bottom_diff[i * dim + j] = max(Dtype(0), 1 + bottom_diff[i * dim + j]);
+    }
+  }
+  return caffe_cpu_asum(count, bottom_diff) / num;
+}
+
+template <typename Dtype>
+void HingeLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  const Dtype* label = (*bottom)[1]->cpu_data();
+  int num = (*bottom)[0]->num();
+  int count = (*bottom)[0]->count();
+  int dim = count / num;
+
+  caffe_cpu_sign(count, bottom_diff, bottom_diff);
+  for (int i = 0; i < num; ++i) {
+    bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
+  }
+  caffe_scal(count, Dtype(1. / num), bottom_diff);
+}
+
+template <typename Dtype>
 void AccuracyLayer<Dtype>::SetUp(
   const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 2) << "Accuracy Layer takes two blobs as input.";
@@ -205,55 +248,10 @@ Dtype AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   return Dtype(0);
 }
 
-template <typename Dtype>
-void HingeLossLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 2) << "Hinge Loss Layer takes two blobs as input.";
-  CHECK_EQ(top->size(), 0) << "Hinge Loss Layer takes no output.";
-}
-
-template <typename Dtype>
-Dtype HingeLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  const Dtype* bottom_data = bottom[0]->cpu_data();
-  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
-  const Dtype* label = bottom[1]->cpu_data();
-  int num = bottom[0]->num();
-  int count = bottom[0]->count();
-  int dim = count / num;
-
-  caffe_copy(count, bottom_data, bottom_diff);
-  for (int i = 0; i < num; ++i) {
-    bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
-  }
-  for (int i = 0; i < num; ++i) {
-    for (int j = 0; j < dim; ++j) {
-      bottom_diff[i * dim + j] = max(Dtype(0), 1 + bottom_diff[i * dim + j]);
-    }
-  }
-  return caffe_cpu_asum(count, bottom_diff) / num;
-}
-
-template <typename Dtype>
-void HingeLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  const Dtype* label = (*bottom)[1]->cpu_data();
-  int num = (*bottom)[0]->num();
-  int count = (*bottom)[0]->count();
-  int dim = count / num;
-
-  caffe_cpu_sign(count, bottom_diff, bottom_diff);
-  for (int i = 0; i < num; ++i) {
-    bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
-  }
-  caffe_scal(count, Dtype(1. / num), bottom_diff);
-}
-
 INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
 INSTANTIATE_CLASS(InfogainLossLayer);
 INSTANTIATE_CLASS(EuclideanLossLayer);
-INSTANTIATE_CLASS(AccuracyLayer);
 INSTANTIATE_CLASS(HingeLossLayer);
+INSTANTIATE_CLASS(AccuracyLayer);
 
 }  // namespace caffe
index 767601c..a638684 100644 (file)
@@ -13,16 +13,10 @@ using std::max;
 namespace caffe {
 
 template <typename Dtype>
-void SigmoidCrossEntropyLossLayer<Dtype>::SetUp(
+void SigmoidCrossEntropyLossLayer<Dtype>::FurtherSetUp(
     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 2) <<
-      "SigmoidCrossEntropyLoss Layer takes two blobs as input.";
-  CHECK_EQ(top->size(), 0) <<
-      "SigmoidCrossEntropyLoss Layer takes no blob as output.";
   CHECK_EQ(bottom[0]->count(), bottom[1]->count()) <<
       "SigmoidCrossEntropyLoss Layer inputs must have same count.";
-  CHECK_EQ(bottom[0]->num(), bottom[1]->num()) <<
-      "SigmoidCrossEntropyLoss Layer inputs must have same num.";
   sigmoid_bottom_vec_.clear();
   sigmoid_bottom_vec_.push_back(bottom[0]);
   sigmoid_top_vec_.clear();