From 859da3d27c872f42e4edf8ecb32dd7db74483d43 Mon Sep 17 00:00:00 2001
From: Yangqing Jia
Date: Wed, 18 Sep 2013 15:29:13 -0700
Subject: [PATCH] inner product forward backward

---
 src/caffeine/common.hpp                        |  2 +
 src/caffeine/layers/inner_product_layer.cu     | 67 +++++++++++++++-----------
 src/caffeine/test/test_gradient_check_util.cpp | 11 ++++-
 src/caffeine/test/test_innerproduct_layer.cpp  | 31 ++++++++++++
 4 files changed, 81 insertions(+), 30 deletions(-)

diff --git a/src/caffeine/common.hpp b/src/caffeine/common.hpp
index 1bcd785..6721e26 100644
--- a/src/caffeine/common.hpp
+++ b/src/caffeine/common.hpp
@@ -19,6 +19,8 @@
   template class classname<float>; \
   template class classname<double>
 
+#define NOT_IMPLEMENTED CHECK(false) << "Not Implemented"
+
 namespace caffeine {
 
 // We will use the boost shared_ptr instead of the new C++11 one mainly
diff --git a/src/caffeine/layers/inner_product_layer.cu b/src/caffeine/layers/inner_product_layer.cu
index 26a7c4b..5b1124c 100644
--- a/src/caffeine/layers/inner_product_layer.cu
+++ b/src/caffeine/layers/inner_product_layer.cu
@@ -67,7 +67,22 @@ Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
-  CHECK(false);
+  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+  // Gradient with respect to weight
+  caffeine_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
+      bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_cpu_diff());
+  if (biasterm_) {
+    // Gradient with respect to bias
+    caffeine_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
+        (Dtype*)bias_multiplier_->cpu_data(), (Dtype)0.,
+        this->blobs_[1].mutable_cpu_diff());
+  }
+  if (propagate_down) {
+    // Gradient with respect to bottom data
+    caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
+        top_diff, this->blobs_[0].cpu_data(), (Dtype)0.,
+        (*bottom)[0]->mutable_cpu_diff());
+  }
   return Dtype(0);
 }
 
@@ -89,32 +104,12 @@ void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const Dtype* weight = this->blobs_[0].gpu_data();
-  const Dtype* bias = NULL;
-  Dtype alpha = 1., beta = 0.;
+  caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
+      bottom_data, weight, (Dtype)0., top_data);
   if (biasterm_) {
-    bias = this->blobs_[1].gpu_data();
-    beta = 1.;
-    const int count = (*top)[0]->count();
-    // we pre-copy the bias to the results, and then call gemm.
-    BroadcastRow<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
-        count, N_, bias, top_data);
-  }
-  switch(sizeof(Dtype)) {
-  case sizeof(float):
-    // matrix multiply: since cublas uses Fortran major, we actually do
-    // C' = B' A'
-    CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
-        CUBLAS_OP_N, N_, M_, K_, (float*)&alpha, (const float*)weight, N_,
-        (const float*)bottom_data, K_, (float*)&beta, (float*)top_data, N_));
-    break;
-  case sizeof(double):
-    // matrix multiply
-    CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
-        CUBLAS_OP_N, N_, M_, K_, (double*)&alpha, (const double*)weight, N_,
-        (const double*)bottom_data, K_, (double*)&beta, (double*)top_data, N_));
-    break;
-  default:
-    CHECK(false) << "Unknown data type.";
+    caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
+        (Dtype*)bias_multiplier_->gpu_data(), this->blobs_[1].gpu_data(),
+        (Dtype)1., top_data);
   }
 }
 
@@ -122,8 +117,24 @@ template <typename Dtype>
 Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  CHECK(false);
-  return Dtype(0.);
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+  // Gradient with respect to weight
+  caffeine_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
+      bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_gpu_diff());
+  if (biasterm_) {
+    // Gradient with respect to bias
+    caffeine_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
+        (Dtype*)bias_multiplier_->gpu_data(), (Dtype)0.,
+        this->blobs_[1].mutable_gpu_diff());
+  }
+  if (propagate_down) {
+    // Gradient with respect to bottom data
+    caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
+        top_diff, this->blobs_[0].gpu_data(), (Dtype)0.,
+        (*bottom)[0]->mutable_gpu_diff());
+  }
+  return Dtype(0);
 }
 
 INSTANTIATE_CLASS(InnerProductLayer);
diff --git a/src/caffeine/test/test_gradient_check_util.cpp b/src/caffeine/test/test_gradient_check_util.cpp
index 94875db..d1748d2 100644
--- a/src/caffeine/test/test_gradient_check_util.cpp
+++ b/src/caffeine/test/test_gradient_check_util.cpp
@@ -1,8 +1,11 @@
+#include <algorithm>
 #include <cmath>
 #include <cstring>
 
 #include "caffeine/test/test_gradient_check_util.hpp"
 
+using std::max;
+
 namespace caffeine {
 
 template <typename Dtype>
@@ -58,8 +61,12 @@ void GradientChecker<Dtype>::CheckGradient(Layer<Dtype>& layer,
       //LOG(ERROR) << "debug: " << current_blob->cpu_data()[feat_id] << " "
       //    << current_blob->cpu_diff()[feat_id];
       if (kink_ - kink_range_ > feature || feature > kink_ + kink_range_) {
-        EXPECT_GT(computed_gradient, estimated_gradient - threshold_);
-        EXPECT_LT(computed_gradient, estimated_gradient + threshold_);
+        // We check relative accuracy, but for too small values, we threshold
+        // the scale factor by 1.
+        Dtype scale = max(max(fabs(computed_gradient), fabs(estimated_gradient)),
+            1.);
+        EXPECT_GT(computed_gradient, estimated_gradient - threshold_ * scale);
+        EXPECT_LT(computed_gradient, estimated_gradient + threshold_ * scale);
       }
       //LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id];
       //LOG(ERROR) << "computed gradient: " << computed_gradient
diff --git a/src/caffeine/test/test_innerproduct_layer.cpp b/src/caffeine/test/test_innerproduct_layer.cpp
index 1a43750..eb5e341 100644
--- a/src/caffeine/test/test_innerproduct_layer.cpp
+++ b/src/caffeine/test/test_innerproduct_layer.cpp
@@ -6,6 +6,7 @@
 #include "caffeine/common.hpp"
 #include "caffeine/filler.hpp"
 #include "caffeine/vision_layers.hpp"
+#include "caffeine/test/test_gradient_check_util.hpp"
 
 namespace caffeine {
 
@@ -88,4 +89,34 @@ TYPED_TEST(InnerProductLayerTest, TestGPU) {
   }
 }
 
+TYPED_TEST(InnerProductLayerTest, TestCPUGradient) {
+  LayerParameter layer_param;
+  Caffeine::set_mode(Caffeine::CPU);
+  layer_param.set_num_output(10);
+  layer_param.mutable_weight_filler()->set_type("uniform");
+  layer_param.mutable_bias_filler()->set_type("uniform");
+  layer_param.mutable_bias_filler()->set_min(1);
+  layer_param.mutable_bias_filler()->set_max(2);
+  InnerProductLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+TYPED_TEST(InnerProductLayerTest, TestGPUGradient) {
+  if (sizeof(TypeParam) == 4 || CAFFEINE_TEST_CUDA_PROP.major >= 2) {
+    LayerParameter layer_param;
+    Caffeine::set_mode(Caffeine::GPU);
+    layer_param.set_num_output(10);
+    layer_param.mutable_weight_filler()->set_type("uniform");
+    layer_param.mutable_bias_filler()->set_type("uniform");
+    layer_param.mutable_bias_filler()->set_min(1);
+    layer_param.mutable_bias_filler()->set_max(2);
+    InnerProductLayer<TypeParam> layer(layer_param);
+    GradientChecker<TypeParam> checker(1e-2, 1e-2);
+    checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+  } else {
+    LOG(ERROR) << "Skipping test due to old architecture.";
+  }
+}
+
 }
-- 
2.7.4
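
Note on the algebra behind the new Backward_{cpu,gpu}: with bottom data X of
shape M_ x K_ (batch size by input dimension), weight W = blobs_[0] of shape
K_ x N_, and bias b = blobs_[1] of length N_, the forward pass is
Y = X * W + 1 * b', where 1 is the M_-vector of ones held in bias_multiplier_.
The chain rule then gives exactly the three calls in the patch: dW = X' * dY,
db = dY' * 1 (column sums of the top diff), and, when propagate_down is set,
dX = dY * W'. Below is a minimal, self-contained CPU sketch of those shapes
and calls; the naive gemm() helper stands in for caffeine_cpu_gemm (which
wraps the BLAS gemm), and everything else is illustrative, not caffeine code.

#include <cstdio>
#include <vector>

// Row-major C = alpha * op(A) * op(B) + beta * C, where op() optionally
// transposes. A naive stand-in for the BLAS gemm the patch dispatches to.
static void gemm(bool trans_a, bool trans_b, int m, int n, int k,
                 float alpha, const float* a, const float* b,
                 float beta, float* c) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        const float av = trans_a ? a[p * m + i] : a[i * k + p];
        const float bv = trans_b ? b[j * k + p] : b[p * n + j];
        acc += av * bv;
      }
      c[i * n + j] = alpha * acc + beta * c[i * n + j];
    }
  }
}

int main() {
  const int M = 2, K = 3, N = 4;       // mirrors M_, K_, N_ in the layer
  std::vector<float> X(M * K, 1.f);    // bottom_data
  std::vector<float> W(K * N, 0.5f);   // blobs_[0]
  std::vector<float> b(N, 0.1f);       // blobs_[1]
  std::vector<float> ones(M, 1.f);     // bias_multiplier_
  std::vector<float> Y(M * N, 0.f);    // top_data

  // Forward: Y = X * W, then the rank-1 update Y += ones * b', which is
  // how the patch broadcasts the bias without a dedicated kernel.
  gemm(false, false, M, N, K, 1.f, X.data(), W.data(), 0.f, Y.data());
  gemm(false, false, M, N, 1, 1.f, ones.data(), b.data(), 1.f, Y.data());

  // Backward, for a given top diff dY (all ones here):
  std::vector<float> dY(M * N, 1.f);
  std::vector<float> dW(K * N), db(N), dX(M * K);
  // dW = X' * dY (the CblasTrans, CblasNoTrans call)
  gemm(true, false, K, N, M, 1.f, X.data(), dY.data(), 0.f, dW.data());
  // db = dY' * ones, i.e. column sums of dY (the gemv, written as n=1 gemm)
  gemm(true, false, N, 1, M, 1.f, dY.data(), ones.data(), 0.f, db.data());
  // dX = dY * W' (only when propagate_down)
  gemm(false, true, M, K, N, 1.f, dY.data(), W.data(), 0.f, dX.data());

  // Expected: Y entries 1.6, dW and db entries 2 (= M), dX entries 2 (= N * 0.5).
  std::printf("Y[0]=%g dW[0]=%g db[0]=%g dX[0]=%g\n", Y[0], dW[0], db[0], dX[0]);
  return 0;
}

Treating the bias as a rank-1 GEMM update is also why the old BroadcastRow
kernel and the float/double switch over cublasSgemm/cublasDgemm can be
deleted: a single templated caffeine_gpu_gemm path now covers both the
matrix product and the bias broadcast.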
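
A similar note on the gradient-checker change: a fixed absolute threshold is
too strict when the true gradient is large and the finite-difference estimate
carries proportional error, so the test now checks relative accuracy, flooring
the scale at 1 so that near-zero gradients still face an absolute tolerance of
threshold_. A hypothetical helper (gradients_agree is not a caffeine function)
showing the accept/reject behavior under that rule:

#include <algorithm>
#include <cassert>
#include <cmath>

// Accept if |computed - estimated| < threshold * scale, with
// scale = max(|computed|, |estimated|, 1), as in the patched check.
bool gradients_agree(double computed, double estimated, double threshold) {
  const double scale =
      std::max(std::max(std::fabs(computed), std::fabs(estimated)), 1.0);
  return computed > estimated - threshold * scale &&
         computed < estimated + threshold * scale;
}

int main() {
  // Large gradients get proportional slack: at scale 1000 and threshold
  // 1e-2, a gap of 2 is well inside the allowed 10.
  assert(gradients_agree(998.0, 1000.0, 1e-2));
  // Near zero, the floor of 1 leaves an absolute slack of 1e-2.
  assert(gradients_agree(0.005, 0.0, 1e-2));
  assert(!gradients_agree(0.02, 0.0, 1e-2));
  return 0;
}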