strategize softmax
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 2 Sep 2014 04:29:50 +0000 (21:29 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Sun, 7 Sep 2014 01:27:07 +0000 (03:27 +0200)
include/caffe/common_layers.hpp
include/caffe/loss_layers.hpp
src/caffe/layer_factory.cpp
src/caffe/layers/caffe_softmax_layer.cpp [new file with mode: 0644]
src/caffe/layers/caffe_softmax_layer.cu [moved from src/caffe/layers/softmax_layer.cu with 96% similarity]
src/caffe/layers/softmax_layer.cpp
src/caffe/test/test_softmax_layer.cpp

index 3753592..b3f6981 100644 (file)
@@ -361,6 +361,31 @@ class SoftmaxLayer : public Layer<Dtype> {
 
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) = 0;
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) = 0;
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+     const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+};
+
+template <typename Dtype>
+class CaffeSoftmaxLayer : public SoftmaxLayer<Dtype> {
+ public:
+  explicit CaffeSoftmaxLayer(const LayerParameter& param)
+      : SoftmaxLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  virtual inline LayerParameter_LayerType type() const {
+    return LayerParameter_LayerType_SOFTMAX;
+  }
+  virtual inline int ExactNumBottomBlobs() const { return 1; }
+  virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
index 555c06e..30a4c37 100644 (file)
@@ -567,8 +567,8 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   vector<Blob<Dtype>*> sigmoid_top_vec_;
 };
 
-// Forward declare SoftmaxLayer for use in SoftmaxWithLossLayer.
-template <typename Dtype> class SoftmaxLayer;
+// Forward declare CaffeSoftmaxLayer for use in SoftmaxWithLossLayer.
+template <typename Dtype> class CaffeSoftmaxLayer;
 
 /**
  * @brief Computes the multinomial logistic loss for a one-of-many
@@ -603,7 +603,7 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> {
  public:
   explicit SoftmaxWithLossLayer(const LayerParameter& param)
       : LossLayer<Dtype>(param),
-        softmax_layer_(new SoftmaxLayer<Dtype>(param)) {}
+        softmax_layer_(new CaffeSoftmaxLayer<Dtype>(param)) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
 
@@ -657,7 +657,7 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> {
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   /// The internal SoftmaxLayer used to map predictions to a distribution.
-  shared_ptr<SoftmaxLayer<Dtype> > softmax_layer_;
+  shared_ptr<CaffeSoftmaxLayer<Dtype> > softmax_layer_;
   /// prob stores the output probability predictions from the SoftmaxLayer.
   Blob<Dtype> prob_;
   /// bottom vector holder used in call to the underlying SoftmaxLayer::Forward
index 3a535f5..afb530a 100644 (file)
@@ -94,6 +94,23 @@ template TanHLayer<float>* GetTanHLayer(const string& name,
 template TanHLayer<double>* GetTanHLayer(const string& name,
     const LayerParameter& param);
 
+// Get softmax layer according to engine.
+template <typename Dtype>
+SoftmaxLayer<Dtype>* GetSoftmaxLayer(const string& name,
+    const LayerParameter& param) {
+  SoftmaxParameter_Engine engine = param.softmax_param().engine();
+  if (engine == SoftmaxParameter_Engine_CAFFE) {
+    return new CaffeSoftmaxLayer<Dtype>(param);
+  } else {
+    LOG(FATAL) << "Layer " << name << " has unknown engine.";
+  }
+}
+
+template SoftmaxLayer<float>* GetSoftmaxLayer(const string& name,
+    const LayerParameter& param);
+template SoftmaxLayer<double>* GetSoftmaxLayer(const string& name,
+    const LayerParameter& param);
+
 // A function to get a specific layer from the specification given in
 // LayerParameter. Ideally this would be replaced by a factory pattern,
 // but we will leave it this way for now.
@@ -163,7 +180,7 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
   case LayerParameter_LayerType_SLICE:
     return new SliceLayer<Dtype>(param);
   case LayerParameter_LayerType_SOFTMAX:
-    return new SoftmaxLayer<Dtype>(param);
+    return GetSoftmaxLayer<Dtype>(name, param);
   case LayerParameter_LayerType_SOFTMAX_LOSS:
     return new SoftmaxWithLossLayer<Dtype>(param);
   case LayerParameter_LayerType_SPLIT:
diff --git a/src/caffe/layers/caffe_softmax_layer.cpp b/src/caffe/layers/caffe_softmax_layer.cpp
new file mode 100644 (file)
index 0000000..64d027c
--- /dev/null
@@ -0,0 +1,97 @@
+//
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CaffeSoftmaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  SoftmaxLayer<Dtype>::LayerSetUp(bottom, top);
+  sum_multiplier_.Reshape(1, bottom[0]->channels(), 1, 1);
+  Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
+  for (int i = 0; i < sum_multiplier_.count(); ++i) {
+    multiplier_data[i] = 1.;
+  }
+  scale_.Reshape(bottom[0]->num(), 1, bottom[0]->height(), bottom[0]->width());
+}
+
+template <typename Dtype>
+void CaffeSoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  int num = bottom[0]->num();
+  int channels = bottom[0]->channels();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  int spatial_dim = bottom[0]->height() * bottom[0]->width();
+  caffe_copy(bottom[0]->count(), bottom_data, top_data);
+  // We need to subtract the max to avoid numerical issues, compute the exp,
+  // and then normalize.
+  for (int i = 0; i < num; ++i) {
+    // initialize scale_data to the first plane
+    caffe_copy(spatial_dim, bottom_data + i * dim, scale_data);
+    for (int j = 0; j < channels; j++) {
+      for (int k = 0; k < spatial_dim; k++) {
+        scale_data[k] = std::max(scale_data[k],
+            bottom_data[i * dim + j * spatial_dim + k]);
+      }
+    }
+    // subtraction
+    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim,
+        1, -1., sum_multiplier_.cpu_data(), scale_data, 1., top_data + i * dim);
+    // exponentiation
+    caffe_exp<Dtype>(dim, top_data + i * dim, top_data + i * dim);
+    // sum after exp
+    caffe_cpu_gemv<Dtype>(CblasTrans, channels, spatial_dim, 1.,
+        top_data + i * dim, sum_multiplier_.cpu_data(), 0., scale_data);
+    // division
+    for (int j = 0; j < channels; j++) {
+      caffe_div(spatial_dim, top_data + (*top)[0]->offset(i, j), scale_data,
+          top_data + (*top)[0]->offset(i, j));
+    }
+  }
+}
+
+template <typename Dtype>
+void CaffeSoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  const Dtype* top_data = top[0]->cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  int num = top[0]->num();
+  int channels = top[0]->channels();
+  int dim = top[0]->count() / top[0]->num();
+  int spatial_dim = top[0]->height() * top[0]->width();
+  caffe_copy(top[0]->count(), top_diff, bottom_diff);
+  for (int i = 0; i < num; ++i) {
+    // compute dot(top_diff, top_data) and subtract them from the bottom diff
+    for (int k = 0; k < spatial_dim; ++k) {
+      scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,
+          bottom_diff + i * dim + k, spatial_dim,
+          top_data + i * dim + k, spatial_dim);
+    }
+    // subtraction
+    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim, 1,
+        -1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
+  }
+  // elementwise multiplication
+  caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff);
+}
+
+
+#ifdef CPU_ONLY
+STUB_GPU(CaffeSoftmaxLayer);
+#endif
+
+INSTANTIATE_CLASS(CaffeSoftmaxLayer);
+
+
+}  // namespace caffe
similarity index 96%
rename from src/caffe/layers/softmax_layer.cu
rename to src/caffe/layers/caffe_softmax_layer.cu
index f97eafc..74f6a7d 100644 (file)
@@ -86,7 +86,7 @@ __global__ void kernel_channel_dot(const int num, const int channels,
 }
 
 template <typename Dtype>
-void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void CaffeSoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -125,7 +125,7 @@ void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
-void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+void CaffeSoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->gpu_diff();
   const Dtype* top_data = top[0]->gpu_data();
@@ -148,7 +148,6 @@ void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   caffe_gpu_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
 }
 
-INSTANTIATE_CLASS(SoftmaxLayer);
-
+INSTANTIATE_CLASS(CaffeSoftmaxLayer);
 
 }  // namespace caffe
index 29767ac..06b5e2b 100644 (file)
@@ -13,86 +13,8 @@ void SoftmaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
       bottom[0]->height(), bottom[0]->width());
-  sum_multiplier_.Reshape(1, bottom[0]->channels(), 1, 1);
-  Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
-  for (int i = 0; i < sum_multiplier_.count(); ++i) {
-    multiplier_data[i] = 1.;
-  }
-  scale_.Reshape(bottom[0]->num(), 1, bottom[0]->height(), bottom[0]->width());
 }
 
-template <typename Dtype>
-void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  const Dtype* bottom_data = bottom[0]->cpu_data();
-  Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  Dtype* scale_data = scale_.mutable_cpu_data();
-  int num = bottom[0]->num();
-  int channels = bottom[0]->channels();
-  int dim = bottom[0]->count() / bottom[0]->num();
-  int spatial_dim = bottom[0]->height() * bottom[0]->width();
-  caffe_copy(bottom[0]->count(), bottom_data, top_data);
-  // We need to subtract the max to avoid numerical issues, compute the exp,
-  // and then normalize.
-  for (int i = 0; i < num; ++i) {
-    // initialize scale_data to the first plane
-    caffe_copy(spatial_dim, bottom_data + i * dim, scale_data);
-    for (int j = 0; j < channels; j++) {
-      for (int k = 0; k < spatial_dim; k++) {
-        scale_data[k] = std::max(scale_data[k],
-            bottom_data[i * dim + j * spatial_dim + k]);
-      }
-    }
-    // subtraction
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim,
-        1, -1., sum_multiplier_.cpu_data(), scale_data, 1., top_data + i * dim);
-    // exponentiation
-    caffe_exp<Dtype>(dim, top_data + i * dim, top_data + i * dim);
-    // sum after exp
-    caffe_cpu_gemv<Dtype>(CblasTrans, channels, spatial_dim, 1.,
-        top_data + i * dim, sum_multiplier_.cpu_data(), 0., scale_data);
-    // division
-    for (int j = 0; j < channels; j++) {
-      caffe_div(spatial_dim, top_data + (*top)[0]->offset(i, j), scale_data,
-          top_data + (*top)[0]->offset(i, j));
-    }
-  }
-}
-
-template <typename Dtype>
-void SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const vector<bool>& propagate_down,
-    vector<Blob<Dtype>*>* bottom) {
-  const Dtype* top_diff = top[0]->cpu_diff();
-  const Dtype* top_data = top[0]->cpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  Dtype* scale_data = scale_.mutable_cpu_data();
-  int num = top[0]->num();
-  int channels = top[0]->channels();
-  int dim = top[0]->count() / top[0]->num();
-  int spatial_dim = top[0]->height() * top[0]->width();
-  caffe_copy(top[0]->count(), top_diff, bottom_diff);
-  for (int i = 0; i < num; ++i) {
-    // compute dot(top_diff, top_data) and subtract them from the bottom diff
-    for (int k = 0; k < spatial_dim; ++k) {
-      scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,
-          bottom_diff + i * dim + k, spatial_dim,
-          top_data + i * dim + k, spatial_dim);
-    }
-    // subtraction
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim, 1,
-        -1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
-  }
-  // elementwise multiplication
-  caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff);
-}
-
-
-#ifdef CPU_ONLY
-STUB_GPU(SoftmaxLayer);
-#endif
-
 INSTANTIATE_CLASS(SoftmaxLayer);
 
-
 }  // namespace caffe
index 9f45f76..18c68b3 100644 (file)
@@ -40,7 +40,7 @@ TYPED_TEST_CASE(SoftmaxLayerTest, TestDtypesAndDevices);
 TYPED_TEST(SoftmaxLayerTest, TestForward) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  SoftmaxLayer<Dtype> layer(layer_param);
+  CaffeSoftmaxLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
   layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
   // Test sum
@@ -74,7 +74,7 @@ TYPED_TEST(SoftmaxLayerTest, TestForward) {
 TYPED_TEST(SoftmaxLayerTest, TestGradient) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  SoftmaxLayer<Dtype> layer(layer_param);
+  CaffeSoftmaxLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-3);
   checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
       &(this->blob_top_vec_));