inner product forward backward
authorYangqing Jia <jiayq84@gmail.com>
Wed, 18 Sep 2013 22:29:13 +0000 (15:29 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Wed, 18 Sep 2013 22:29:13 +0000 (15:29 -0700)
src/caffeine/common.hpp
src/caffeine/layers/inner_product_layer.cu
src/caffeine/test/test_gradient_check_util.cpp
src/caffeine/test/test_innerproduct_layer.cpp

index 1bcd785..6721e26 100644 (file)
@@ -19,6 +19,8 @@
   template class classname<float>; \
   template class classname<double>
 
+#define NOT_IMPLEMENTED CHECK(false) << "Not Implemented"
+
 namespace caffeine {
 
 // We will use the boost shared_ptr instead of the new C++11 one mainly
index 26a7c4b..5b1124c 100644 (file)
@@ -67,7 +67,22 @@ Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
-  CHECK(false);
+  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+  // Gradient with respect to weight
+  caffeine_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
+      bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_cpu_diff());
+  if (biasterm_) {
+    // Gradient with respect to bias
+    caffeine_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
+        (Dtype*)bias_multiplier_->cpu_data(), (Dtype)0.,
+        this->blobs_[1].mutable_cpu_diff());
+  }
+  if (propagate_down) {
+    // Gradient with respect to bottom data
+    caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
+        top_diff, this->blobs_[0].cpu_data(), (Dtype)0.,
+        (*bottom)[0]->mutable_cpu_diff());
+  }
   return Dtype(0);
 }
 
@@ -89,32 +104,12 @@ void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const Dtype* weight = this->blobs_[0].gpu_data();
-  const Dtype* bias = NULL;
-  Dtype alpha = 1., beta = 0.;
+  caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
+      bottom_data, weight, (Dtype)0., top_data);
   if (biasterm_) {
-       bias = this->blobs_[1].gpu_data();
-       beta = 1.;
-       const int count = (*top)[0]->count();
-       // we pre-copy the bias to the results, and then call gemm.
-       BroadcastRow<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
-                       count, N_, bias, top_data);
-  }
-  switch(sizeof(Dtype)) {
-  case sizeof(float):
-    // matrix multiply: since cublas uses Fortran major, we actually do
-    // C' = B' A'
-    CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
-        CUBLAS_OP_N, N_, M_, K_, (float*)&alpha, (const float*)weight, N_,
-        (const float*)bottom_data, K_, (float*)&beta, (float*)top_data, N_));
-    break;
-  case sizeof(double):
-    // matrix multiply
-    CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
-        CUBLAS_OP_N, N_, M_, K_, (double*)&alpha, (const double*)weight, N_,
-        (const double*)bottom_data, K_, (double*)&beta, (double*)top_data, N_));
-    break;
-  default:
-    CHECK(false) << "Unknown data type.";
+    caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
+        (Dtype*)bias_multiplier_->gpu_data(), this->blobs_[1].gpu_data(),
+        (Dtype)1., top_data);
   }
 }
 
@@ -122,8 +117,24 @@ template <typename Dtype>
 Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  CHECK(false);
-  return Dtype(0.);
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+  // Gradient with respect to weight
+  caffeine_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
+      bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_gpu_diff());
+  if (biasterm_) {
+    // Gradient with respect to bias
+    caffeine_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
+        (Dtype*)bias_multiplier_->gpu_data(), (Dtype)0.,
+        this->blobs_[1].mutable_gpu_diff());
+  }
+  if (propagate_down) {
+    // Gradient with respect to bottom data
+    caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
+        top_diff, this->blobs_[0].gpu_data(), (Dtype)0.,
+        (*bottom)[0]->mutable_gpu_diff());
+  }
+  return Dtype(0);
 }
 
 INSTANTIATE_CLASS(InnerProductLayer);
index 94875db..d1748d2 100644 (file)
@@ -1,8 +1,11 @@
+#include <algorithm>
 #include <cmath>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 #include "caffeine/test/test_gradient_check_util.hpp"
 
+using std::max;
+
 namespace caffeine {
 
 template <typename Dtype>
@@ -58,8 +61,12 @@ void GradientChecker<Dtype>::CheckGradient(Layer<Dtype>& layer,
       //LOG(ERROR) << "debug: " << current_blob->cpu_data()[feat_id] << " "
       //    << current_blob->cpu_diff()[feat_id];
       if (kink_ - kink_range_ > feature || feature > kink_ + kink_range_) {
-        EXPECT_GT(computed_gradient, estimated_gradient - threshold_);
-        EXPECT_LT(computed_gradient, estimated_gradient + threshold_);
+        // We check relative accuracy, but for too small values, we threshold
+        // the scale factor by 1.
+        Dtype scale = max(max(fabs(computed_gradient), fabs(estimated_gradient)),
+            1.);
+        EXPECT_GT(computed_gradient, estimated_gradient - threshold_ * scale);
+        EXPECT_LT(computed_gradient, estimated_gradient + threshold_ * scale);
       }
       //LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id];
       //LOG(ERROR) << "computed gradient: " << computed_gradient
index 1a43750..eb5e341 100644 (file)
@@ -6,6 +6,7 @@
 #include "caffeine/common.hpp"
 #include "caffeine/filler.hpp"
 #include "caffeine/vision_layers.hpp"
+#include "caffeine/test/test_gradient_check_util.hpp"
 
 namespace caffeine {
 
@@ -88,4 +89,34 @@ TYPED_TEST(InnerProductLayerTest, TestGPU) {
        }
 }
 
+TYPED_TEST(InnerProductLayerTest, TestCPUGradient) {
+  LayerParameter layer_param;
+  Caffeine::set_mode(Caffeine::CPU);
+  layer_param.set_num_output(10);
+  layer_param.mutable_weight_filler()->set_type("uniform");
+  layer_param.mutable_bias_filler()->set_type("uniform");
+  layer_param.mutable_bias_filler()->set_min(1);
+  layer_param.mutable_bias_filler()->set_max(2);
+  InnerProductLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+TYPED_TEST(InnerProductLayerTest, TestGPUGradient) {
+  if (sizeof(TypeParam) == 4 || CAFFEINE_TEST_CUDA_PROP.major >= 2) {
+    LayerParameter layer_param;
+    Caffeine::set_mode(Caffeine::GPU);
+    layer_param.set_num_output(10);
+    layer_param.mutable_weight_filler()->set_type("uniform");
+    layer_param.mutable_bias_filler()->set_type("uniform");
+    layer_param.mutable_bias_filler()->set_min(1);
+    layer_param.mutable_bias_filler()->set_max(2);
+    InnerProductLayer<TypeParam> layer(layer_param);
+    GradientChecker<TypeParam> checker(1e-2, 1e-2);
+    checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
+  } else {
+    LOG(ERROR) << "Skipping test due to old architecture.";
+  }
+}
+
 }