inner product bugfix
authorYangqing Jia <jiayq84@gmail.com>
Fri, 27 Sep 2013 17:02:20 +0000 (10:02 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Fri, 27 Sep 2013 17:02:20 +0000 (10:02 -0700)
src/caffe/layers/inner_product_layer.cpp

index 115dba7..e5e8e52 100644 (file)
@@ -73,7 +73,7 @@ Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_diff = top[0]->cpu_diff();
   const Dtype* bottom_data = (*bottom)[0]->cpu_data();
   // Gradient with respect to weight
-  caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
+  caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
       top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());
   if (biasterm_) {
     // Gradient with respect to bias
@@ -112,7 +112,7 @@ Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_diff = top[0]->gpu_diff();
   const Dtype* bottom_data = (*bottom)[0]->gpu_data();
   // Gradient with respect to weight
-  caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
+  caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
       top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_gpu_diff());
   if (biasterm_) {
     // Gradient with respect to bias