use Blob directly instead of shared_ptr for InnerProductLayer::bias_multiplier_
authorJonathan L Long <jonlong@cs.berkeley.edu>
Sun, 20 Jul 2014 01:02:33 +0000 (18:02 -0700)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Sun, 20 Jul 2014 01:31:32 +0000 (18:31 -0700)
include/caffe/vision_layers.hpp
src/caffe/layers/inner_product_layer.cpp
src/caffe/layers/inner_product_layer.cu

index a261ac3..c995992 100644 (file)
@@ -157,7 +157,7 @@ class InnerProductLayer : public Layer<Dtype> {
   int K_;
   int N_;
   bool bias_term_;
-  shared_ptr<Blob<Dtype> > bias_multiplier_;
+  Blob<Dtype> bias_multiplier_;
 };
 
 // Forward declare PoolingLayer and SplitLayer for use in LRNLayer.
index 4da8578..5505d08 100644 (file)
@@ -47,8 +47,8 @@ void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   }  // parameter initialization
   // Setting up the bias multiplier
   if (bias_term_) {
-    bias_multiplier_.reset(new Blob<Dtype>(1, 1, 1, M_));
-    caffe_set(M_, Dtype(1), bias_multiplier_->mutable_cpu_data());
+    bias_multiplier_.Reshape(1, 1, 1, M_);
+    caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
   }
   this->param_propagate_down_.resize(this->blobs_.size(), true);
 }
@@ -63,7 +63,7 @@ Dtype InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       bottom_data, weight, (Dtype)0., top_data);
   if (bias_term_) {
     caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
-        bias_multiplier_->cpu_data(),
+        bias_multiplier_.cpu_data(),
         this->blobs_[1]->cpu_data(), (Dtype)1., top_data);
   }
   return Dtype(0);
@@ -84,7 +84,7 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->cpu_diff();
     // Gradient with respect to bias
     caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
-        bias_multiplier_->cpu_data(), (Dtype)0.,
+        bias_multiplier_.cpu_data(), (Dtype)0.,
         this->blobs_[1]->mutable_cpu_diff());
   }
   if (propagate_down[0]) {
index 6ef7578..513e4eb 100644 (file)
@@ -21,7 +21,7 @@ Dtype InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       bottom_data, weight, (Dtype)0., top_data);
   if (bias_term_) {
     caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
-        bias_multiplier_->gpu_data(),
+        bias_multiplier_.gpu_data(),
         this->blobs_[1]->gpu_data(), (Dtype)1., top_data);
   }
   return Dtype(0);
@@ -42,7 +42,7 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->gpu_diff();
     // Gradient with respect to bias
     caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
-        bias_multiplier_->gpu_data(), (Dtype)0.,
+        bias_multiplier_.gpu_data(), (Dtype)0.,
         this->blobs_[1]->mutable_gpu_diff());
   }
   if (propagate_down[0]) {