SoftmaxWithLossLayer: use CreateLayer so that a CuDNNSoftmaxLayer
authorJeff Donahue <jeff.donahue@gmail.com>
Sun, 2 Nov 2014 04:37:05 +0000 (21:37 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Thu, 22 Jan 2015 00:53:40 +0000 (16:53 -0800)
is created if available

include/caffe/loss_layers.hpp
src/caffe/layers/softmax_loss_layer.cpp

index 13b108a..4350136 100644 (file)
@@ -697,8 +697,7 @@ template <typename Dtype>
 class SoftmaxWithLossLayer : public LossLayer<Dtype> {
  public:
   explicit SoftmaxWithLossLayer(const LayerParameter& param)
-      : LossLayer<Dtype>(param),
-        softmax_layer_(new SoftmaxLayer<Dtype>(param)) {}
+      : LossLayer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
@@ -751,7 +750,7 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> {
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
   /// The internal SoftmaxLayer used to map predictions to a distribution.
-  shared_ptr<SoftmaxLayer<Dtype> > softmax_layer_;
+  shared_ptr<Layer<Dtype> > softmax_layer_;
   /// prob stores the output probability predictions from the SoftmaxLayer.
   Blob<Dtype> prob_;
   /// bottom vector holder used in call to the underlying SoftmaxLayer::Forward
index db8dd8b..dfc41d2 100644 (file)
@@ -3,6 +3,7 @@
 #include <vector>
 
 #include "caffe/layer.hpp"
+#include "caffe/layer_factory.hpp"
 #include "caffe/util/math_functions.hpp"
 #include "caffe/vision_layers.hpp"
 
@@ -12,6 +13,9 @@ template <typename Dtype>
 void SoftmaxWithLossLayer<Dtype>::LayerSetUp(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   LossLayer<Dtype>::LayerSetUp(bottom, top);
+  LayerParameter softmax_param(this->layer_param_);
+  softmax_param.set_type(LayerParameter_LayerType_SOFTMAX);
+  softmax_layer_.reset(LayerRegistry<Dtype>::CreateLayer(softmax_param));
   softmax_bottom_vec_.clear();
   softmax_bottom_vec_.push_back(bottom[0]);
   softmax_top_vec_.clear();