protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) = 0;
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) = 0;
+ virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+};
+
+template <typename Dtype>
+class CaffeSoftmaxLayer : public SoftmaxLayer<Dtype> {
+ public:
+ explicit CaffeSoftmaxLayer(const LayerParameter& param)
+ : SoftmaxLayer<Dtype>(param) {}
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ virtual inline LayerParameter_LayerType type() const {
+ return LayerParameter_LayerType_SOFTMAX;
+ }
+ virtual inline int ExactNumBottomBlobs() const { return 1; }
+ virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
vector<Blob<Dtype>*> sigmoid_top_vec_;
};
-// Forward declare SoftmaxLayer for use in SoftmaxWithLossLayer.
-template <typename Dtype> class SoftmaxLayer;
+// Forward declare CaffeSoftmaxLayer for use in SoftmaxWithLossLayer.
+template <typename Dtype> class CaffeSoftmaxLayer;
/**
* @brief Computes the multinomial logistic loss for a one-of-many
public:
explicit SoftmaxWithLossLayer(const LayerParameter& param)
: LossLayer<Dtype>(param),
- softmax_layer_(new SoftmaxLayer<Dtype>(param)) {}
+ softmax_layer_(new CaffeSoftmaxLayer<Dtype>(param)) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
/// The internal SoftmaxLayer used to map predictions to a distribution.
- shared_ptr<SoftmaxLayer<Dtype> > softmax_layer_;
+ shared_ptr<CaffeSoftmaxLayer<Dtype> > softmax_layer_;
/// prob stores the output probability predictions from the SoftmaxLayer.
Blob<Dtype> prob_;
/// bottom vector holder used in call to the underlying SoftmaxLayer::Forward
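For context, the two holder vectors and prob_ exist so the loss layer can drive its internal softmax with ordinary bottom/top vectors. A minimal sketch of how that wiring typically looks in SoftmaxWithLossLayer::LayerSetUp (illustrative only, not part of this patch; the softmax_bottom_vec_/softmax_top_vec_ member names are assumed from the comments above):

template <typename Dtype>
void SoftmaxWithLossLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
  LossLayer<Dtype>::LayerSetUp(bottom, top);
  // Feed the raw predictions to the wrapped softmax and collect its output
  // in prob_, so Forward can read normalized probabilities from prob_.
  softmax_bottom_vec_.clear();
  softmax_bottom_vec_.push_back(bottom[0]);
  softmax_top_vec_.clear();
  softmax_top_vec_.push_back(&prob_);
  softmax_layer_->SetUp(softmax_bottom_vec_, &softmax_top_vec_);
}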
template TanHLayer<double>* GetTanHLayer(const string& name,
const LayerParameter& param);
+// Get softmax layer according to engine.
+template <typename Dtype>
+SoftmaxLayer<Dtype>* GetSoftmaxLayer(const string& name,
+ const LayerParameter& param) {
+ SoftmaxParameter_Engine engine = param.softmax_param().engine();
+ if (engine == SoftmaxParameter_Engine_CAFFE) {
+ return new CaffeSoftmaxLayer<Dtype>(param);
+ } else {
+ LOG(FATAL) << "Layer " << name << " has unknown engine.";
+ }
+ // Unreachable: LOG(FATAL) aborts. The return only silences compiler
+ // warnings about a missing return value.
+ return NULL;
+}
+
+template SoftmaxLayer<float>* GetSoftmaxLayer(const string& name,
+ const LayerParameter& param);
+template SoftmaxLayer<double>* GetSoftmaxLayer(const string& name,
+ const LayerParameter& param);
+
// A function to get a specific layer from the specification given in
// LayerParameter. Ideally this would be replaced by a factory pattern,
// but we will leave it this way for now.
case LayerParameter_LayerType_SLICE:
return new SliceLayer<Dtype>(param);
case LayerParameter_LayerType_SOFTMAX:
- return new SoftmaxLayer<Dtype>(param);
+ return GetSoftmaxLayer<Dtype>(name, param);
case LayerParameter_LayerType_SOFTMAX_LOSS:
return new SoftmaxWithLossLayer<Dtype>(param);
case LayerParameter_LayerType_SPLIT:
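A hypothetical snippet showing the new dispatch end to end: build a SOFTMAX LayerParameter, pick the CAFFE engine, and let the factory return the concrete class. The protobuf setters come from the generated caffe.pb.h; the helper function itself is illustrative and not part of the patch (imagine it appended inside namespace caffe in layer_factory.cpp):

Layer<float>* MakeCaffeSoftmax() {
  LayerParameter param;
  param.set_name("prob");
  param.set_type(LayerParameter_LayerType_SOFTMAX);
  // Any engine other than CAFFE would hit the LOG(FATAL) branch above.
  param.mutable_softmax_param()->set_engine(SoftmaxParameter_Engine_CAFFE);
  // GetLayer() reaches the SOFTMAX case and forwards to GetSoftmaxLayer(),
  // which returns a CaffeSoftmaxLayer<float>.
  return GetLayer<float>(param);
}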
--- /dev/null
+//
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CaffeSoftmaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ SoftmaxLayer<Dtype>::LayerSetUp(bottom, top);
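+ // sum_multiplier_ is a channels-long vector of ones; multiplying by it with
+ // the gemm/gemv calls below either broadcasts one value per spatial location
+ // across all channels or sums values over the channel dimension.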
+ sum_multiplier_.Reshape(1, bottom[0]->channels(), 1, 1);
+ Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
+ for (int i = 0; i < sum_multiplier_.count(); ++i) {
+ multiplier_data[i] = 1.;
+ }
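+ // scale_ holds one scalar per spatial location: the running max in
+ // Forward_cpu and the dot(top_diff, top_data) value in Backward_cpu.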
+ scale_.Reshape(bottom[0]->num(), 1, bottom[0]->height(), bottom[0]->width());
+}
+
+template <typename Dtype>
+void CaffeSoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = (*top)[0]->mutable_cpu_data();
+ Dtype* scale_data = scale_.mutable_cpu_data();
+ int num = bottom[0]->num();
+ int channels = bottom[0]->channels();
+ int dim = bottom[0]->count() / bottom[0]->num();
+ int spatial_dim = bottom[0]->height() * bottom[0]->width();
+ caffe_copy(bottom[0]->count(), bottom_data, top_data);
+ // We need to subtract the max to avoid numerical issues, compute the exp,
+ // and then normalize.
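+ // Shifting by the max leaves the result unchanged, since
+ // exp(x_j - m) / sum_k exp(x_k - m) == exp(x_j) / sum_k exp(x_k),
+ // but it keeps the exponentials from overflowing.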
+ for (int i = 0; i < num; ++i) {
+ // initialize scale_data to the first plane
+ caffe_copy(spatial_dim, bottom_data + i * dim, scale_data);
+ for (int j = 0; j < channels; j++) {
+ for (int k = 0; k < spatial_dim; k++) {
+ scale_data[k] = std::max(scale_data[k],
+ bottom_data[i * dim + j * spatial_dim + k]);
+ }
+ }
+ // subtraction
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim,
+ 1, -1., sum_multiplier_.cpu_data(), scale_data, 1., top_data + i * dim);
+ // exponentiation
+ caffe_exp<Dtype>(dim, top_data + i * dim, top_data + i * dim);
+ // sum after exp
+ caffe_cpu_gemv<Dtype>(CblasTrans, channels, spatial_dim, 1.,
+ top_data + i * dim, sum_multiplier_.cpu_data(), 0., scale_data);
+ // division
+ for (int j = 0; j < channels; j++) {
+ caffe_div(spatial_dim, top_data + (*top)[0]->offset(i, j), scale_data,
+ top_data + (*top)[0]->offset(i, j));
+ }
+ }
+}
+
+template <typename Dtype>
+void CaffeSoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ vector<Blob<Dtype>*>* bottom) {
+ const Dtype* top_diff = top[0]->cpu_diff();
+ const Dtype* top_data = top[0]->cpu_data();
+ Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+ Dtype* scale_data = scale_.mutable_cpu_data();
+ int num = top[0]->num();
+ int channels = top[0]->channels();
+ int dim = top[0]->count() / top[0]->num();
+ int spatial_dim = top[0]->height() * top[0]->width();
+ caffe_copy(top[0]->count(), top_diff, bottom_diff);
+ for (int i = 0; i < num; ++i) {
+ // compute dot(top_diff, top_data) and subtract them from the bottom diff
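+ // Per spatial location this yields
+ //   bottom_diff = top_data .* (top_diff - dot(top_diff, top_data)),
+ // i.e. the softmax Jacobian applied to top_diff; the elementwise
+ // multiplication by top_data happens after the loop over num.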
+ for (int k = 0; k < spatial_dim; ++k) {
+ scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,
+ bottom_diff + i * dim + k, spatial_dim,
+ top_data + i * dim + k, spatial_dim);
+ }
+ // subtraction
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim, 1,
+ -1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
+ }
+ // elementwise multiplication
+ caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff);
+}
+
+
+#ifdef CPU_ONLY
+STUB_GPU(CaffeSoftmaxLayer);
+#endif
+
+INSTANTIATE_CLASS(CaffeSoftmaxLayer);
+
+
+} // namespace caffe
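For reference, this is what the BLAS-based Forward_cpu/Backward_cpu above compute at each spatial location, written as plain loops over the channel dimension. A minimal, self-contained sketch; the function and argument names are hypothetical and not part of the patch:

#include <algorithm>
#include <cmath>

// x: input logits, dy: gradient w.r.t. the output,
// y: output probabilities, dx: gradient w.r.t. the input.
void softmax_reference(const float* x, const float* dy,
                       float* y, float* dx, int channels) {
  // Forward: y_j = exp(x_j - max(x)) / sum_k exp(x_k - max(x))
  float m = x[0];
  for (int j = 1; j < channels; ++j) m = std::max(m, x[j]);
  float sum = 0.f;
  for (int j = 0; j < channels; ++j) { y[j] = std::exp(x[j] - m); sum += y[j]; }
  for (int j = 0; j < channels; ++j) y[j] /= sum;
  // Backward: dx_j = y_j * (dy_j - sum_k dy_k * y_k)
  float dot = 0.f;
  for (int j = 0; j < channels; ++j) dot += dy[j] * y[j];
  for (int j = 0; j < channels; ++j) dx[j] = y[j] * (dy[j] - dot);
}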
}
template <typename Dtype>
-void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void CaffeSoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
}
template <typename Dtype>
-void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+void CaffeSoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* top_data = top[0]->gpu_data();
caffe_gpu_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
}
-INSTANTIATE_CLASS(SoftmaxLayer);
-
+INSTANTIATE_CLASS(CaffeSoftmaxLayer);
} // namespace caffe
vector<Blob<Dtype>*>* top) {
(*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
- sum_multiplier_.Reshape(1, bottom[0]->channels(), 1, 1);
- Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
- for (int i = 0; i < sum_multiplier_.count(); ++i) {
- multiplier_data[i] = 1.;
- }
- scale_.Reshape(bottom[0]->num(), 1, bottom[0]->height(), bottom[0]->width());
}
-template <typename Dtype>
-void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- const Dtype* bottom_data = bottom[0]->cpu_data();
- Dtype* top_data = (*top)[0]->mutable_cpu_data();
- Dtype* scale_data = scale_.mutable_cpu_data();
- int num = bottom[0]->num();
- int channels = bottom[0]->channels();
- int dim = bottom[0]->count() / bottom[0]->num();
- int spatial_dim = bottom[0]->height() * bottom[0]->width();
- caffe_copy(bottom[0]->count(), bottom_data, top_data);
- // We need to subtract the max to avoid numerical issues, compute the exp,
- // and then normalize.
- for (int i = 0; i < num; ++i) {
- // initialize scale_data to the first plane
- caffe_copy(spatial_dim, bottom_data + i * dim, scale_data);
- for (int j = 0; j < channels; j++) {
- for (int k = 0; k < spatial_dim; k++) {
- scale_data[k] = std::max(scale_data[k],
- bottom_data[i * dim + j * spatial_dim + k]);
- }
- }
- // subtraction
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim,
- 1, -1., sum_multiplier_.cpu_data(), scale_data, 1., top_data + i * dim);
- // exponentiation
- caffe_exp<Dtype>(dim, top_data + i * dim, top_data + i * dim);
- // sum after exp
- caffe_cpu_gemv<Dtype>(CblasTrans, channels, spatial_dim, 1.,
- top_data + i * dim, sum_multiplier_.cpu_data(), 0., scale_data);
- // division
- for (int j = 0; j < channels; j++) {
- caffe_div(spatial_dim, top_data + (*top)[0]->offset(i, j), scale_data,
- top_data + (*top)[0]->offset(i, j));
- }
- }
-}
-
-template <typename Dtype>
-void SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
- const vector<bool>& propagate_down,
- vector<Blob<Dtype>*>* bottom) {
- const Dtype* top_diff = top[0]->cpu_diff();
- const Dtype* top_data = top[0]->cpu_data();
- Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
- Dtype* scale_data = scale_.mutable_cpu_data();
- int num = top[0]->num();
- int channels = top[0]->channels();
- int dim = top[0]->count() / top[0]->num();
- int spatial_dim = top[0]->height() * top[0]->width();
- caffe_copy(top[0]->count(), top_diff, bottom_diff);
- for (int i = 0; i < num; ++i) {
- // compute dot(top_diff, top_data) and subtract them from the bottom diff
- for (int k = 0; k < spatial_dim; ++k) {
- scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,
- bottom_diff + i * dim + k, spatial_dim,
- top_data + i * dim + k, spatial_dim);
- }
- // subtraction
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim, 1,
- -1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
- }
- // elementwise multiplication
- caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff);
-}
-
-
-#ifdef CPU_ONLY
-STUB_GPU(SoftmaxLayer);
-#endif
-
INSTANTIATE_CLASS(SoftmaxLayer);
-
} // namespace caffe
TYPED_TEST(SoftmaxLayerTest, TestForward) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
- SoftmaxLayer<Dtype> layer(layer_param);
+ CaffeSoftmaxLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Test sum
TYPED_TEST(SoftmaxLayerTest, TestGradient) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
- SoftmaxLayer<Dtype> layer(layer_param);
+ CaffeSoftmaxLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-3);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));