From: Jonathan L Long
Date: Sun, 20 Jul 2014 01:02:33 +0000 (-0700)
Subject: use Blob directly instead of shared_ptr for InnerProductLayer::bias_multiplier_
X-Git-Tag: submit/tizen/20180823.020014~653^2~54^2~5
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=40054a70cd6458e74ba8d42f80d3e534e9312a1d;p=platform%2Fupstream%2Fcaffeonacl.git

use Blob directly instead of shared_ptr for InnerProductLayer::bias_multiplier_
---

diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
index a261ac3..c995992 100644
--- a/include/caffe/vision_layers.hpp
+++ b/include/caffe/vision_layers.hpp
@@ -157,7 +157,7 @@ class InnerProductLayer : public Layer<Dtype> {
   int K_;
   int N_;
   bool bias_term_;
-  shared_ptr<Blob<Dtype> > bias_multiplier_;
+  Blob<Dtype> bias_multiplier_;
 };

 // Forward declare PoolingLayer and SplitLayer for use in LRNLayer.
diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp
index 4da8578..5505d08 100644
--- a/src/caffe/layers/inner_product_layer.cpp
+++ b/src/caffe/layers/inner_product_layer.cpp
@@ -47,8 +47,8 @@ void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   }  // parameter initialization
   // Setting up the bias multiplier
   if (bias_term_) {
-    bias_multiplier_.reset(new Blob<Dtype>(1, 1, 1, M_));
-    caffe_set(M_, Dtype(1), bias_multiplier_->mutable_cpu_data());
+    bias_multiplier_.Reshape(1, 1, 1, M_);
+    caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
   }
   this->param_propagate_down_.resize(this->blobs_.size(), true);
 }
@@ -63,7 +63,7 @@ Dtype InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       bottom_data, weight, (Dtype)0., top_data);
   if (bias_term_) {
     caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
-        bias_multiplier_->cpu_data(),
+        bias_multiplier_.cpu_data(),
         this->blobs_[1]->cpu_data(), (Dtype)1., top_data);
   }
   return Dtype(0);
@@ -84,7 +84,7 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->cpu_diff();
     // Gradient with respect to bias
     caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
-        bias_multiplier_->cpu_data(), (Dtype)0.,
+        bias_multiplier_.cpu_data(), (Dtype)0.,
         this->blobs_[1]->mutable_cpu_diff());
   }
   if (propagate_down[0]) {
diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu
index 6ef7578..513e4eb 100644
--- a/src/caffe/layers/inner_product_layer.cu
+++ b/src/caffe/layers/inner_product_layer.cu
@@ -21,7 +21,7 @@ Dtype InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       bottom_data, weight, (Dtype)0., top_data);
   if (bias_term_) {
     caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
-        bias_multiplier_->gpu_data(),
+        bias_multiplier_.gpu_data(),
         this->blobs_[1]->gpu_data(), (Dtype)1., top_data);
   }
   return Dtype(0);
@@ -42,7 +42,7 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->gpu_diff();
     // Gradient with respect to bias
     caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
-        bias_multiplier_->gpu_data(), (Dtype)0.,
+        bias_multiplier_.gpu_data(), (Dtype)0.,
         this->blobs_[1]->mutable_gpu_diff());
   }
   if (propagate_down[0]) {
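
The change above follows a simple ownership rule: bias_multiplier_ is created in
SetUp and lives exactly as long as its layer, so shared_ptr's heap allocation and
reference counting buy nothing. Below is a minimal standalone C++ sketch of the
two styles; ToyBlob, LayerBefore, and LayerAfter are hypothetical stand-ins for
illustration, not Caffe types.

// Standalone sketch of the ownership change (hypothetical ToyBlob /
// LayerBefore / LayerAfter types, not Caffe code): a member that is never
// shared and lives exactly as long as its owner can be held by value.
#include <memory>
#include <vector>

class ToyBlob {
 public:
  ToyBlob() {}
  // Allocate (or reallocate) storage, mimicking Blob::Reshape.
  void Reshape(int count) { data_.assign(count, 1.0f); }
  const float* cpu_data() const { return data_.data(); }
 private:
  std::vector<float> data_;
};

// Before: heap allocation plus reference counting for a strictly
// layer-owned member; every access pays a pointer indirection.
class LayerBefore {
 public:
  void SetUp(int m) {
    bias_multiplier_.reset(new ToyBlob());
    bias_multiplier_->Reshape(m);
  }
 private:
  std::shared_ptr<ToyBlob> bias_multiplier_;
};

// After: the member is a plain value; it is constructed and destroyed
// with the layer, and call sites use '.' instead of '->'.
class LayerAfter {
 public:
  void SetUp(int m) { bias_multiplier_.Reshape(m); }
 private:
  ToyBlob bias_multiplier_;
};

int main() {
  LayerBefore before;
  before.SetUp(4);
  LayerAfter after;
  after.SetUp(4);
  return 0;
}

Holding the member by value also turns every -> into . at the call sites, which
is the only change the Forward/Backward implementations in the patch needed.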