inner product bugfix: not tested yet
authorYangqing Jia <jiayq84@gmail.com>
Fri, 27 Sep 2013 16:48:37 +0000 (09:48 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Fri, 27 Sep 2013 16:48:37 +0000 (09:48 -0700)
src/caffe/filler.hpp
src/caffe/layers/inner_product_layer.cpp
src/caffe/layers/multinomial_logistic_loss_layer.cpp [deleted file]

index e945307..f455460 100644 (file)
@@ -97,6 +97,16 @@ class PositiveUnitballFiller : public Filler<Dtype> {
   }
 };
 
+template <typename Dtype>
+class XavierFiller : public Filler<Dtype> {
+ public:
+  explicit XavierFiller(const FillerParameter& param)
+      : Filler<Dtype>(param) {}
+  virtual void Fill(Blob<Dtype>* blob) {
+    
+  }
+};
+
 
 // A function to get a specific filler from the specification given in
 // FillerParameter. Ideally this would be replaced by a factory pattern,
index 55b51c6..115dba7 100644 (file)
@@ -57,7 +57,7 @@ void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
   const Dtype* weight = this->blobs_[0]->cpu_data();
-  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,
       bottom_data, weight, (Dtype)0., top_data);
   if (biasterm_) {
     caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
@@ -74,7 +74,7 @@ Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   const Dtype* bottom_data = (*bottom)[0]->cpu_data();
   // Gradient with respect to weight
   caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
-      bottom_data, top_diff, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());
+      top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());
   if (biasterm_) {
     // Gradient with respect to bias
     caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
@@ -83,7 +83,7 @@ Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   }
   if (propagate_down) {
     // Gradient with respect to bottom data
-    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
+    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
         top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,
         (*bottom)[0]->mutable_cpu_diff());
   }
@@ -96,7 +96,7 @@ void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const Dtype* weight = this->blobs_[0]->gpu_data();
-  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,
       bottom_data, weight, (Dtype)0., top_data);
   if (biasterm_) {
     caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
@@ -113,7 +113,7 @@ Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* bottom_data = (*bottom)[0]->gpu_data();
   // Gradient with respect to weight
   caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
-      bottom_data, top_diff, (Dtype)0., this->blobs_[0]->mutable_gpu_diff());
+      top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_gpu_diff());
   if (biasterm_) {
     // Gradient with respect to bias
     caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
@@ -122,7 +122,7 @@ Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   }
   if (propagate_down) {
     // Gradient with respect to bottom data
-    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
+    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
         top_diff, this->blobs_[0]->gpu_data(), (Dtype)0.,
         (*bottom)[0]->mutable_gpu_diff());
   }
diff --git a/src/caffe/layers/multinomial_logistic_loss_layer.cpp b/src/caffe/layers/multinomial_logistic_loss_layer.cpp
deleted file mode 100644 (file)
index 5ffa4ac..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2013 Yangqing Jia
-
-#include "caffe/layer.hpp"
-#include "caffe/vision_layers.hpp"
-#include "caffe/util/math_functions.hpp"
-#include <algorithm>
-#include <cmath>
-
-using std::max;
-
-namespace caffe {
-
-template <typename Dtype>
-void MultinomialLogisticLossLayer<Dtype>::SetUp(
-    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
-  CHECK_EQ(top->size(), 0) << "Loss Layer takes no as output.";
-  CHECK_EQ(bottom[0]->num(), bottom[1]->num())
-      << "The data and label should have the same number.";
-  CHECK_EQ(bottom[1]->channels(), 1);
-  CHECK_EQ(bottom[1]->height(), 1);
-  CHECK_EQ(bottom[1]->width(), 1);
-};
-
-
-template <typename Dtype>
-Dtype MultinomialLogisticLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
-    vector<Blob<Dtype>*>* bottom) {
-  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
-  const Dtype* bottom_label = (*bottom)[1]->cpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  int num = (*bottom)[0]->num();
-  int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
-  memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count());
-  Dtype loss = 0;
-  const Dtype kLOG_THRESHOLD = 1e-8;
-  for (int i = 0; i < num; ++i) {
-    int label = static_cast<int>(bottom_label[i]);
-    Dtype prob = max(bottom_data[i * dim + label], kLOG_THRESHOLD);
-    loss -= log(prob);
-    bottom_diff[i * dim + label] = - 1. / prob / num;
-  }
-  return loss / num;
-}
-
-// TODO: implement the GPU version
-
-INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
-
-
-}  // namespace caffe