misc update
author    Yangqing Jia <jiayq84@gmail.com>
Wed, 9 Oct 2013 23:41:35 +0000 (16:41 -0700)
committer Yangqing Jia <jiayq84@gmail.com>
Wed, 9 Oct 2013 23:41:35 +0000 (16:41 -0700)
15 files changed:
src/caffe/common.hpp
src/caffe/layers/conv_layer.cpp
src/caffe/layers/inner_product_layer.cpp
src/caffe/layers/pooling_layer.cu
src/caffe/layers/relu_layer.cu
src/caffe/net.cpp
src/caffe/net.hpp
src/caffe/optimization/solver.cpp
src/caffe/test/test_convolution_layer.cpp
src/caffe/test/test_gradient_check_util.hpp
src/caffe/test/test_innerproduct_layer.cpp
src/caffe/test/test_softmax_with_loss_layer.cpp
src/programs/demo_mnist.cpp
src/programs/dump_network.cpp
src/programs/train_alexnet.cpp

diff --git a/src/caffe/common.hpp b/src/caffe/common.hpp
index 4e7e9ad..4b04070 100644
@@ -18,9 +18,9 @@
 #define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
 
 #define CUDA_POST_KERNEL_CHECK \
-  if (cudaSuccess != cudaPeekAtLastError()) {\
-    LOG(FATAL) << "Cuda kernel failed. Error: " << cudaGetLastError(); \
-  }
+  if (cudaSuccess != cudaPeekAtLastError()) \
+    LOG(FATAL) << "Cuda kernel failed. Error: " \
+        << cudaGetErrorString(cudaPeekAtLastError())
 
 #define DISABLE_COPY_AND_ASSIGN(classname) \
 private:\
diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp
index 8bf913a..16c9362 100644
@@ -39,24 +39,32 @@ void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   K_ = CHANNELS_ * KSIZE_ * KSIZE_ / GROUP_;
   N_ = height_out * width_out;
   (*top)[0]->Reshape(bottom[0]->num(), NUM_OUTPUT_, height_out, width_out);
-  if (biasterm_) {
-    this->blobs_.resize(2);
+  // Check if we need to set up the weights
+  if (this->blobs_.size() > 0) {
+    LOG(INFO) << "Skipping parameter initialization";
   } else {
-    this->blobs_.resize(1);
+    if (biasterm_) {
+      this->blobs_.resize(2);
+    } else {
+      this->blobs_.resize(1);
+    }
+    // Initialize the weights
+    this->blobs_[0].reset(
+        new Blob<Dtype>(NUM_OUTPUT_, CHANNELS_ / GROUP_, KSIZE_, KSIZE_));
+    // fill the weights
+    shared_ptr<Filler<Dtype> > weight_filler(
+        GetFiller<Dtype>(this->layer_param_.weight_filler()));
+    weight_filler->Fill(this->blobs_[0].get());
+    // If necessary, initialize and fill the bias term
+    if (biasterm_) {
+      this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, NUM_OUTPUT_));
+      shared_ptr<Filler<Dtype> > bias_filler(
+          GetFiller<Dtype>(this->layer_param_.bias_filler()));
+      bias_filler->Fill(this->blobs_[1].get());
+    }
   }
-  // Intialize the weight
-  this->blobs_[0].reset(
-      new Blob<Dtype>(NUM_OUTPUT_, CHANNELS_ / GROUP_, KSIZE_, KSIZE_));
-  // fill the weights
-  shared_ptr<Filler<Dtype> > weight_filler(
-      GetFiller<Dtype>(this->layer_param_.weight_filler()));
-  weight_filler->Fill(this->blobs_[0].get());
-  // If necessary, intiialize and fill the bias term
+  // Set up the bias multiplier
   if (biasterm_) {
-    this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, NUM_OUTPUT_));
-    shared_ptr<Filler<Dtype> > bias_filler(
-        GetFiller<Dtype>(this->layer_param_.bias_filler()));
-    bias_filler->Fill(this->blobs_[1].get());
     bias_multiplier_.reset(new SyncedMemory(N_ * sizeof(Dtype)));
     Dtype* bias_multiplier_data =
         reinterpret_cast<Dtype*>(bias_multiplier_->mutable_cpu_data());
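
With this restructuring, SetUp leaves an already-populated blobs_ vector alone, so calling SetUp again on the same layer (as the gradient checker below does on layers of an already-initialized net) no longer re-randomizes the parameters. A minimal sketch of the resulting behavior, assuming layer_param, bottom_vec, and top_vec are prepared as in the convolution tests:

    ConvolutionLayer<float> layer(layer_param);
    layer.SetUp(bottom_vec, &top_vec);  // allocates blobs_ and runs the configured fillers
    layer.SetUp(bottom_vec, &top_vec);  // logs "Skipping parameter initialization"; weights are kept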
diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp
index e6f77ae..ef98593 100644
@@ -27,23 +27,31 @@ void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   K_ = bottom[0]->count() / bottom[0]->num();
   N_ = num_output;
   (*top)[0]->Reshape(bottom[0]->num(), num_output, 1, 1);
-  if (biasterm_) {
-    this->blobs_.resize(2);
+  // Check if we need to set up the weights
+  if (this->blobs_.size() > 0) {
+    LOG(INFO) << "Skipping parameter initialization";
   } else {
-    this->blobs_.resize(1);
-  }
-  // Intialize the weight
-  this->blobs_[0].reset(new Blob<Dtype>(1, 1, N_, K_));
-  // fill the weights
-  shared_ptr<Filler<Dtype> > weight_filler(
-      GetFiller<Dtype>(this->layer_param_.weight_filler()));
-  weight_filler->Fill(this->blobs_[0].get());
-  // If necessary, intiialize and fill the bias term
+    if (biasterm_) {
+      this->blobs_.resize(2);
+    } else {
+      this->blobs_.resize(1);
+    }
+    // Initialize the weights
+    this->blobs_[0].reset(new Blob<Dtype>(1, 1, N_, K_));
+    // fill the weights
+    shared_ptr<Filler<Dtype> > weight_filler(
+        GetFiller<Dtype>(this->layer_param_.weight_filler()));
+    weight_filler->Fill(this->blobs_[0].get());
+    // If necessary, initialize and fill the bias term
+    if (biasterm_) {
+      this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, N_));
+      shared_ptr<Filler<Dtype> > bias_filler(
+          GetFiller<Dtype>(this->layer_param_.bias_filler()));
+      bias_filler->Fill(this->blobs_[1].get());
+    }
+  } // parameter initialization
+  // Setting up the bias multiplier
   if (biasterm_) {
-    this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, N_));
-    shared_ptr<Filler<Dtype> > bias_filler(
-        GetFiller<Dtype>(this->layer_param_.bias_filler()));
-    bias_filler->Fill(this->blobs_[1].get());
     bias_multiplier_.reset(new SyncedMemory(M_ * sizeof(Dtype)));
     Dtype* bias_multiplier_data =
         reinterpret_cast<Dtype*>(bias_multiplier_->mutable_cpu_data());
diff --git a/src/caffe/layers/pooling_layer.cu b/src/caffe/layers/pooling_layer.cu
index 43d1ab5..706ee15 100644
@@ -149,9 +149,6 @@ __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
         // figure out the pooling size
         int poolsize = (min(ph * stride + ksize, height) - ph * stride) *
             (min(pw * stride + ksize, width) - pw * stride);
-        if (poolsize <= 0) {
-          printf("error: %d %d %d %d %d\n", ph, pw, ksize, height, width);
-        }
         gradient += top_diff[ph * pooled_width + pw] / poolsize;
       }
     }
diff --git a/src/caffe/layers/relu_layer.cu b/src/caffe/layers/relu_layer.cu
index 5e788c6..c386dd0 100644
@@ -51,6 +51,11 @@ void ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const int count = bottom[0]->count();
   ReLUForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
       count, bottom_data, top_data);
+  CUDA_POST_KERNEL_CHECK;
+  // << " count: " << count << " bottom_data: "
+  //     << (unsigned long)bottom_data << " top_data: " << (unsigned long)top_data
+  //     << " blocks: " << CAFFE_GET_BLOCKS(count)
+  //     << " threads: " << CAFFE_CUDA_NUM_THREADS;
 }
 
 template <typename Dtype>
@@ -73,6 +78,7 @@ Dtype ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const int count = (*bottom)[0]->count();
     ReLUBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, top_diff, bottom_data, bottom_diff);
+    CUDA_POST_KERNEL_CHECK;
   }
   return Dtype(0);
 }
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 6795ccc..edc036b 100644
@@ -111,7 +111,7 @@ const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
     blobs_[net_input_blob_indices_[i]]->CopyFrom(*bottom[i]);
   }
   for (int i = 0; i < layers_.size(); ++i) {
-    // LOG(ERROR) << "Forwarding " << layer_names_[i];
+    //LOG(ERROR) << "Forwarding " << layer_names_[i];
     layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
   }
   return net_output_blobs_;
@@ -177,7 +177,7 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
       layer_connection->add_top(blob_names_[top_id_vecs_[i][j]]);
     }
     LayerParameter* layer_parameter = layer_connection->mutable_layer();
-    layers_[i]->ToProto(layer_parameter);
+    layers_[i]->ToProto(layer_parameter, write_diff);
   }
 }
 
diff --git a/src/caffe/net.hpp b/src/caffe/net.hpp
index 4b24c23..799bc1a 100644
@@ -51,6 +51,10 @@ class Net {
   inline const vector<shared_ptr<Blob<Dtype> > >& blobs() { return blobs_; }
   // returns the layers
   inline const vector<shared_ptr<Layer<Dtype> > >& layers() { return layers_; }
+  // returns the bottom and top vecs for each layer - usually you won't need
+  // this unless you do per-layer checks such as gradients.
+  inline vector<vector<Blob<Dtype>*> >& bottom_vecs() { return bottom_vecs_; }
+  inline vector<vector<Blob<Dtype>*> >& top_vecs() { return top_vecs_; }
   // returns the parameters
   vector<shared_ptr<Blob<Dtype> > >& params() { return params_; };
   // Updates the network
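
The new accessors are what CheckGradientNet in test_gradient_check_util.hpp below relies on. A minimal standalone sketch, assuming net_param and input are a NetParameter and input blob vector prepared elsewhere:

    Net<float> net(net_param, input);
    net.Forward(input);
    vector<vector<Blob<float>*> >& tops = net.top_vecs();
    for (int i = 0; i < net.layers().size(); ++i) {
      LOG(INFO) << "layer " << i << " has " << tops[i].size() << " top blob(s)";
    }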
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
index 6d82df3..3459cc4 100644
@@ -99,6 +99,8 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
   Dtype rate = GetLearningRate();
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();
+  LOG(ERROR) << "rate:" << rate << " momentum:" << momentum
+      << " weight_decay:" << weight_decay;
   switch (Caffe::mode()) {
   case Caffe::CPU:
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp
index fdd870f..6397ae1 100644
@@ -145,9 +145,11 @@ TYPED_TEST(ConvolutionLayerTest, TestCPUGradient) {
   layer_param.set_kernelsize(3);
   layer_param.set_stride(2);
   layer_param.set_num_output(2);
+  layer_param.mutable_weight_filler()->set_type("gaussian");
+  layer_param.mutable_bias_filler()->set_type("gaussian");
   Caffe::set_mode(Caffe::CPU);
   ConvolutionLayer<TypeParam> layer(layer_param);
-  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
 
@@ -157,9 +159,11 @@ TYPED_TEST(ConvolutionLayerTest, TestCPUGradientGroup) {
   layer_param.set_stride(2);
   layer_param.set_num_output(3);
   layer_param.set_group(3);
+  layer_param.mutable_weight_filler()->set_type("gaussian");
+  layer_param.mutable_bias_filler()->set_type("gaussian");
   Caffe::set_mode(Caffe::CPU);
   ConvolutionLayer<TypeParam> layer(layer_param);
-  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
 
@@ -168,9 +172,11 @@ TYPED_TEST(ConvolutionLayerTest, TestGPUGradient) {
   layer_param.set_kernelsize(3);
   layer_param.set_stride(2);
   layer_param.set_num_output(2);
+  layer_param.mutable_weight_filler()->set_type("gaussian");
+  layer_param.mutable_bias_filler()->set_type("gaussian");
   Caffe::set_mode(Caffe::GPU);
   ConvolutionLayer<TypeParam> layer(layer_param);
-  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
 
@@ -180,9 +186,11 @@ TYPED_TEST(ConvolutionLayerTest, TestGPUGradientGroup) {
   layer_param.set_stride(2);
   layer_param.set_num_output(3);
   layer_param.set_group(3);
+  layer_param.mutable_weight_filler()->set_type("gaussian");
+  layer_param.mutable_bias_filler()->set_type("gaussian");
   Caffe::set_mode(Caffe::GPU);
   ConvolutionLayer<TypeParam> layer(layer_param);
-  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
 
diff --git a/src/caffe/test/test_gradient_check_util.hpp b/src/caffe/test/test_gradient_check_util.hpp
index 55a5b95..4955394 100644
@@ -11,6 +11,7 @@
 #include <vector>
 
 #include "caffe/layer.hpp"
+#include "caffe/net.hpp"
 
 using std::max;
 
@@ -31,7 +32,7 @@ class GradientChecker {
   // the parameters of the layer, as well as the input blobs if check_through
   // is set True.
   // Note that after the gradient check, we do not guarantee that the data
-  // stored in the layer parameters and the blobs.
+  // stored in the layer parameters and the blobs are unchanged.
   void CheckGradient(Layer<Dtype>& layer, vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>& top, int check_bottom = -1) {
       layer.SetUp(bottom, &top);
@@ -45,6 +46,12 @@ class GradientChecker {
       vector<Blob<Dtype>*>& top, int check_bottom, int top_id,
       int top_data_id);
 
+  // Checks the gradient of a network. This network should not have any data
+  // layers or loss layers, since the function does not explicitly deal with
+  // such cases yet. All input blobs and parameter blobs are checked,
+  // layer by layer, to keep numerical errors from accumulating.
+  void CheckGradientNet(Net<Dtype>& net, vector<Blob<Dtype>*>& input);
+
  protected:
   Dtype GetObjAndGradient(vector<Blob<Dtype>*>& top, int top_id = -1,
       int top_data_id = -1);
@@ -124,7 +131,7 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
       }
       // LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id];
       // LOG(ERROR) << "computed gradient: " << computed_gradient
-      //     << " estimated_gradient: " << estimated_gradient;
+      //    << " estimated_gradient: " << estimated_gradient;
     }
   }
 }
@@ -136,7 +143,7 @@ void GradientChecker<Dtype>::CheckGradientExhaustive(Layer<Dtype>& layer,
   layer.SetUp(bottom, &top);
   // LOG(ERROR) << "Exhaustive Mode.";
   for (int i = 0; i < top.size(); ++i) {
-    // LOG(ERROR) << "Exhaustive: blob " << i << " size " << top[i]->count();
+    LOG(ERROR) << "Exhaustive: blob " << i << " size " << top[i]->count();
     for (int j = 0; j < top[i]->count(); ++j) {
       // LOG(ERROR) << "Exhaustive: blob " << i << " data " << j;
       CheckGradientSingle(layer, bottom, top, check_bottom, i, j);
@@ -145,6 +152,19 @@ void GradientChecker<Dtype>::CheckGradientExhaustive(Layer<Dtype>& layer,
 }
 
 template <typename Dtype>
+void GradientChecker<Dtype>::CheckGradientNet(
+    Net<Dtype>& net, vector<Blob<Dtype>*>& input) {
+  const vector<shared_ptr<Layer<Dtype> > >& layers = net.layers();
+  vector<vector<Blob<Dtype>*> >& bottom_vecs = net.bottom_vecs();
+  vector<vector<Blob<Dtype>*> >& top_vecs = net.top_vecs();
+  for (int i = 0; i < layers.size(); ++i) {
+    net.Forward(input);
+    LOG(ERROR) << "Checking gradient for " << layers[i]->layer_param().name();
+    CheckGradientExhaustive(*(layers[i].get()), bottom_vecs[i], top_vecs[i]);
+  }
+}
+
+template <typename Dtype>
 Dtype GradientChecker<Dtype>::GetObjAndGradient(vector<Blob<Dtype>*>& top,
     int top_id, int top_data_id) {
   Dtype loss = 0;
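
A sketch of how a test might drive the new CheckGradientNet. The prototxt path, blob shape, and the use of caffe's ReadProtoFromTextFile helper are assumptions, not part of this commit; the net it describes should contain no data or loss layers, per the comment above:

    NetParameter net_param;
    ReadProtoFromTextFile("gradient_check_net.prototxt", &net_param);  // hypothetical net definition
    Blob<float> data(2, 3, 6, 6);
    FillerParameter filler_param;
    GaussianFiller<float> filler(filler_param);
    filler.Fill(&data);
    vector<Blob<float>*> input;
    input.push_back(&data);
    Net<float> net(net_param, input);
    GradientChecker<float> checker(1e-2, 1e-3);
    checker.CheckGradientNet(net, input);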
diff --git a/src/caffe/test/test_innerproduct_layer.cpp b/src/caffe/test/test_innerproduct_layer.cpp
index 3ccd34e..0e2b612 100644
@@ -97,12 +97,12 @@ TYPED_TEST(InnerProductLayerTest, TestCPUGradient) {
   LayerParameter layer_param;
   Caffe::set_mode(Caffe::CPU);
   layer_param.set_num_output(10);
-  layer_param.mutable_weight_filler()->set_type("uniform");
-  layer_param.mutable_bias_filler()->set_type("uniform");
+  layer_param.mutable_weight_filler()->set_type("gaussian");
+  layer_param.mutable_bias_filler()->set_type("gaussian");
   layer_param.mutable_bias_filler()->set_min(1);
   layer_param.mutable_bias_filler()->set_max(2);
   InnerProductLayer<TypeParam> layer(layer_param);
-  GradientChecker<TypeParam> checker(1e-2, 1e-2);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
   checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
 
@@ -111,10 +111,8 @@ TYPED_TEST(InnerProductLayerTest, TestGPUGradient) {
     LayerParameter layer_param;
     Caffe::set_mode(Caffe::GPU);
     layer_param.set_num_output(10);
-    layer_param.mutable_weight_filler()->set_type("uniform");
-    layer_param.mutable_bias_filler()->set_type("uniform");
-    layer_param.mutable_bias_filler()->set_min(1);
-    layer_param.mutable_bias_filler()->set_max(2);
+    layer_param.mutable_weight_filler()->set_type("gaussian");
+    layer_param.mutable_bias_filler()->set_type("gaussian");
     InnerProductLayer<TypeParam> layer(layer_param);
     GradientChecker<TypeParam> checker(1e-2, 1e-2);
     checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
diff --git a/src/caffe/test/test_softmax_with_loss_layer.cpp b/src/caffe/test/test_softmax_with_loss_layer.cpp
index a955192..328f64b 100644
@@ -26,6 +26,7 @@ class SoftmaxWithLossLayerTest : public ::testing::Test {
         blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)) {
     // fill the values
     FillerParameter filler_param;
+    filler_param.set_std(10);
     GaussianFiller<Dtype> filler(filler_param);
     filler.Fill(this->blob_bottom_data_);
     blob_bottom_vec_.push_back(blob_bottom_data_);
diff --git a/src/programs/demo_mnist.cpp b/src/programs/demo_mnist.cpp
index 6d15d75..f442fe6 100644
@@ -17,7 +17,7 @@
 using namespace caffe;
 
 int main(int argc, char** argv) {
-  cudaSetDevice(1);
+  cudaSetDevice(0);
   Caffe::set_mode(Caffe::GPU);
 
   NetParameter net_param;
diff --git a/src/programs/dump_network.cpp b/src/programs/dump_network.cpp
index 2c05297..3507100 100644
@@ -24,7 +24,7 @@ using namespace caffe;
 
 int main(int argc, char** argv) {
   cudaSetDevice(1);
-  Caffe::set_mode(Caffe::CPU);
+  Caffe::set_mode(Caffe::GPU);
   Caffe::set_phase(Caffe::TEST);
 
   NetParameter net_param;
diff --git a/src/programs/train_alexnet.cpp b/src/programs/train_alexnet.cpp
index 2063efd..3fc2139 100644
@@ -17,7 +17,7 @@
 using namespace caffe;
 
 int main(int argc, char** argv) {
-  cudaSetDevice(1);
+  cudaSetDevice(0);
   Caffe::set_mode(Caffe::GPU);
 
   NetParameter net_param;
@@ -32,14 +32,28 @@ int main(int argc, char** argv) {
   LOG(ERROR) << "Performing Backward";
   LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
 
+  /*
+  // Now, let's dump all the layers
+  string output_prefix("alexnet_initial_dump_");
+  const vector<string>& blob_names = caffe_net.blob_names();
+  const vector<shared_ptr<Blob<float> > >& blobs = caffe_net.blobs();
+  for (int blobid = 0; blobid < caffe_net.blobs().size(); ++blobid) {
+    // Serialize blob
+    LOG(ERROR) << "Dumping " << blob_names[blobid];
+    BlobProto output_blob_proto;
+    blobs[blobid]->ToProto(&output_blob_proto);
+    WriteProtoToBinaryFile(output_blob_proto, output_prefix + blob_names[blobid]);
+  }
+  */
+
   SolverParameter solver_param;
-  solver_param.set_base_lr(0.001);
+  solver_param.set_base_lr(0.01);
   solver_param.set_display(1);
-  solver_param.set_max_iter(60000);
+  solver_param.set_max_iter(2);
   solver_param.set_lr_policy("fixed");
   solver_param.set_momentum(0.9);
   solver_param.set_weight_decay(0.0005);
-  solver_param.set_snapshot(1000);
+  solver_param.set_snapshot(1);
   solver_param.set_snapshot_prefix("alexnet");
 
   LOG(ERROR) << "Starting Optimization";