a lot of modifications - disallow copy constructors and misc
authorYangqing Jia <jiayq84@gmail.com>
Thu, 26 Sep 2013 23:08:40 +0000 (16:08 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Thu, 26 Sep 2013 23:08:40 +0000 (16:08 -0700)
16 files changed:
src/caffe/blob.cpp
src/caffe/blob.hpp
src/caffe/common.hpp
src/caffe/layer.hpp
src/caffe/layer_factory.hpp
src/caffe/layers/conv_layer.cpp
src/caffe/layers/inner_product_layer.cpp
src/caffe/net.cpp
src/caffe/net.hpp
src/caffe/proto/lenet.prototxt [deleted file]
src/caffe/test/lenet.hpp [new file with mode: 0644]
src/caffe/test/test_blob.cpp
src/caffe/test/test_gradient_check_util.hpp
src/caffe/test/test_innerproduct_layer.cpp
src/caffe/test/test_net_proto.cpp [new file with mode: 0644]
src/caffe/test/test_pooling_layer.cpp

index ecb37b7..0a00ce5 100644 (file)
@@ -11,17 +11,22 @@ namespace caffe {
 template <typename Dtype>
 void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
     const int width) {
-  CHECK_GT(num, 0);
-  CHECK_GT(channels, 0);
-  CHECK_GT(height, 0);
-  CHECK_GT(width, 0);
+  CHECK_GE(num, 0);
+  CHECK_GE(channels, 0);
+  CHECK_GE(height, 0);
+  CHECK_GE(width, 0);
   num_ = num;
   channels_ = channels;
   height_ = height;
   width_ = width;
   count_ = num_ * channels_ * height_ * width_;
-  data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
-  diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+  if (count_) {
+    data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+    diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+  } else {
+    data_.reset((SyncedMemory*)NULL);
+    diff_.reset((SyncedMemory*)NULL);
+  }
 }
 
 template <typename Dtype>
@@ -31,37 +36,6 @@ Blob<Dtype>::Blob(const int num, const int channels, const int height,
 }
 
 template <typename Dtype>
-Blob<Dtype>::Blob(const Blob<Dtype>& source) {
-  if (source.count() == 0) {
-    Blob();
-  } else {
-    Reshape(source.num(), source.channels(), source.height(),
-        source.width());
-    if (count_ > 0) {
-      // Copy the data.
-      memcpy(data_->mutable_cpu_data(), source.cpu_data(),
-          count_ * sizeof(Dtype));
-      memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
-          count_ * sizeof(Dtype));
-    }
-  }
-}
-
-template <typename Dtype>
-const Blob<Dtype>& Blob<Dtype>::operator=(const Blob<Dtype>& source) {
-  Reshape(source.num(), source.channels(), source.height(),
-        source.width());
-  if (count_ > 0) {
-    // Copy the data.
-    memcpy(data_->mutable_cpu_data(), source.cpu_data(),
-        count_ * sizeof(Dtype));
-    memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
-        count_ * sizeof(Dtype));
-  }
-  return (*this);
-}
-
-template <typename Dtype>
 const Dtype* Blob<Dtype>::cpu_data() const {
   CHECK(data_);
   return (const Dtype*)data_->cpu_data();
index a14c046..39a2cf0 100644 (file)
@@ -17,8 +17,6 @@ class Blob {
        diff_() {}
   explicit Blob(const int num, const int channels, const int height,
     const int width);
-  Blob(const Blob<Dtype>& source);
-  const Blob<Dtype>& operator=(const Blob<Dtype>& src);
   virtual ~Blob() {}
   void Reshape(const int num, const int height,
       const int width, const int channels);
@@ -62,6 +60,8 @@ class Blob {
   int height_;
   int width_;
   int count_;
+
+  DISABLE_COPY_AND_ASSIGN(Blob);
 };  // class Blob
 
 }  // namespace caffe
index 67177a6..18c5b41 100644 (file)
     LOG(FATAL) << "Cuda kernel failed. Error: " << cudaGetLastError(); \
   }
 
+#define DISABLE_COPY_AND_ASSIGN(classname) \
+ private:\
+  classname(const classname&);\
+  classname& operator=(const classname&)
+
 #define INSTANTIATE_CLASS(classname) \
   template class classname<float>; \
   template class classname<double>
index 551ed3b..b82f038 100644 (file)
@@ -35,7 +35,7 @@ class Layer {
       vector<Blob<Dtype>*>* bottom);
 
   // Returns the vector of parameters.
-  vector<Blob<Dtype> >& params() {
+  vector<shared_ptr<Blob<Dtype> > >& params() {
     return blobs_;
   }
 
@@ -46,7 +46,7 @@ class Layer {
   // The protobuf that stores the layer parameters
   LayerParameter layer_param_;
   // The vector that stores the parameters as a set of blobs.
-  vector<Blob<Dtype> > blobs_;
+  vector<shared_ptr<Blob<Dtype> > > blobs_;
 
   // Forward functions
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
@@ -70,6 +70,8 @@ class Layer {
     LOG(WARNING) << "Using CPU code as backup.";
     return Backward_cpu(top, propagate_down, bottom);
   };
+
+  DISABLE_COPY_AND_ASSIGN(Layer);
 };  // class Layer
 
 // Forward and backward wrappers. You should implement the cpu and
@@ -110,7 +112,7 @@ void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
   param->CopyFrom(layer_param_);
   param->clear_blobs();
   for (int i = 0; i < blobs_.size(); ++i) {
-    blobs_[i].ToProto(param->add_blobs(), write_diff);
+    blobs_[i]->ToProto(param->add_blobs(), write_diff);
   }
 }
 
index 90e6d66..8453fd5 100644 (file)
@@ -37,6 +37,10 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new PoolingLayer<Dtype>(param);
   } else if (type == "relu") {
     return new ReLULayer<Dtype>(param);
+  } else if (type == "softmax") {
+    return new SoftmaxLayer<Dtype>(param);
+  } else if (type == "multinomial_logistic_loss") {
+    return new MultinomialLogisticLossLayer<Dtype>(param);
   } else {
     LOG(FATAL) << "Unknown filler name: " << type;
   }
index 849e106..9560e47 100644 (file)
@@ -13,8 +13,8 @@ namespace caffe {
 template <typename Dtype>
 void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 1) << "Im2col Layer takes a single blob as input.";
-  CHECK_EQ(top->size(), 1) << "Im2col Layer takes a single blob as output.";
+  CHECK_EQ(bottom.size(), 1) << "Conv Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "Conv Layer takes a single blob as output.";
   KSIZE_ = this->layer_param_.kernelsize();
   STRIDE_ = this->layer_param_.stride();
   GROUP_ = this->layer_param_.group();
@@ -23,6 +23,7 @@ void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   HEIGHT_ = bottom[0]->height();
   WIDTH_ = bottom[0]->width();
   NUM_OUTPUT_ = this->layer_param_.num_output();
+  CHECK_GT(NUM_OUTPUT_, 0);
   CHECK_EQ(CHANNELS_ % GROUP_, 0);
   // The im2col result buffer would only hold one image at a time to avoid
   // overly large memory usage.
@@ -44,17 +45,17 @@ void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
     this->blobs_.resize(1);
   }
   // Intialize the weight
-  this->blobs_[0].Reshape(1, 1, NUM_OUTPUT_, K_);
+  this->blobs_[0].reset(new Blob<Dtype>(1, 1, NUM_OUTPUT_, K_));
   // fill the weights
   shared_ptr<Filler<Dtype> > weight_filler(
       GetFiller<Dtype>(this->layer_param_.weight_filler()));
-  weight_filler->Fill(&this->blobs_[0]);
+  weight_filler->Fill(this->blobs_[0].get());
   // If necessary, intiialize and fill the bias term
   if (biasterm_) {
-    this->blobs_[1].Reshape(1, 1, 1, NUM_OUTPUT_);
+    this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, NUM_OUTPUT_));
     shared_ptr<Filler<Dtype> > bias_filler(
         GetFiller<Dtype>(this->layer_param_.bias_filler()));
-    bias_filler->Fill(&this->blobs_[1]);
+    bias_filler->Fill(this->blobs_[1].get());
     bias_multiplier_.reset(new SyncedMemory(N_ * sizeof(Dtype)));
     Dtype* bias_multiplier_data =
         reinterpret_cast<Dtype*>(bias_multiplier_->mutable_cpu_data());
@@ -71,7 +72,7 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
   Dtype* col_data = col_buffer_.mutable_cpu_data();
-  const Dtype* weight = this->blobs_[0].cpu_data();
+  const Dtype* weight = this->blobs_[0]->cpu_data();
   int weight_offset = M_ * K_;
   int col_offset = K_ * N_;
   int top_offset = M_ * N_;
@@ -88,7 +89,7 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     // third, add bias
     if (biasterm_) {
       caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
-          N_, 1, (Dtype)1., this->blobs_[1].cpu_data(),
+          N_, 1, (Dtype)1., this->blobs_[1]->cpu_data(),
           reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()),
           (Dtype)1., top_data + (*top)[0]->offset(n));
     }
@@ -101,7 +102,7 @@ void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   Dtype* col_data = col_buffer_.mutable_gpu_data();
-  const Dtype* weight = this->blobs_[0].gpu_data();
+  const Dtype* weight = this->blobs_[0]->gpu_data();
   int weight_offset = M_ * K_;
   int col_offset = K_ * N_;
   int top_offset = M_ * N_;
@@ -118,7 +119,7 @@ void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     // third, add bias
     if (biasterm_) {
       caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
-          N_, 1, (Dtype)1., this->blobs_[1].gpu_data(),
+          N_, 1, (Dtype)1., this->blobs_[1]->gpu_data(),
           reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
           (Dtype)1., top_data + (*top)[0]->offset(n));
     }
@@ -129,8 +130,8 @@ template <typename Dtype>
 Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
-  const Dtype* weight = this->blobs_[0].cpu_data();
-  Dtype* weight_diff = this->blobs_[0].mutable_cpu_diff();
+  const Dtype* weight = this->blobs_[0]->cpu_data();
+  Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
   const Dtype* bottom_data = (*bottom)[0]->cpu_data();
   Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
   Dtype* col_data = col_buffer_.mutable_cpu_data();
@@ -139,8 +140,8 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   Dtype* bias_diff = NULL;
 
   if (biasterm_) {
-    bias_diff = this->blobs_[1].mutable_cpu_diff();
-    memset(bias_diff, 0., sizeof(Dtype) * this->blobs_[1].count());
+    bias_diff = this->blobs_[1]->mutable_cpu_diff();
+    memset(bias_diff, 0., sizeof(Dtype) * this->blobs_[1]->count());
     for (int n = 0; n < NUM_; ++n) {
       caffe_cpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
           1., top_diff + top[0]->offset(n),
@@ -152,7 +153,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   int weight_offset = M_ * K_;
   int col_offset = K_ * N_;
   int top_offset = M_ * N_;
-  memset(weight_diff, 0., sizeof(Dtype) * this->blobs_[0].count());
+  memset(weight_diff, 0., sizeof(Dtype) * this->blobs_[0]->count());
   for (int n = 0; n < NUM_; ++n) {
     // since we saved memory in the forward pass by not storing all col data,
     // we will need to recompute them.
@@ -185,8 +186,8 @@ template <typename Dtype>
 Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->gpu_diff();
-  const Dtype* weight = this->blobs_[0].gpu_data();
-  Dtype* weight_diff = this->blobs_[0].mutable_gpu_diff();
+  const Dtype* weight = this->blobs_[0]->gpu_data();
+  Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
   const Dtype* bottom_data = (*bottom)[0]->gpu_data();
   Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
   Dtype* col_data = col_buffer_.mutable_gpu_data();
@@ -195,9 +196,9 @@ Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   Dtype* bias_diff = NULL;
 
   if (biasterm_) {
-    bias_diff = this->blobs_[1].mutable_gpu_diff();
+    bias_diff = this->blobs_[1]->mutable_gpu_diff();
     CUDA_CHECK(cudaMemset(bias_diff, 0.,
-        sizeof(Dtype) * this->blobs_[1].count()));
+        sizeof(Dtype) * this->blobs_[1]->count()));
     for (int n = 0; n < NUM_; ++n) {
       caffe_gpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
           1., top_diff + top[0]->offset(n),
@@ -210,7 +211,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   int col_offset = K_ * N_;
   int top_offset = M_ * N_;
   CUDA_CHECK(cudaMemset(weight_diff, 0.,
-      sizeof(Dtype) * this->blobs_[0].count()));
+      sizeof(Dtype) * this->blobs_[0]->count()));
   for (int n = 0; n < NUM_; ++n) {
     // since we saved memory in the forward pass by not storing all col data,
     // we will need to recompute them.
index b39b568..55b51c6 100644 (file)
@@ -32,17 +32,17 @@ void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
     this->blobs_.resize(1);
   }
   // Intialize the weight
-  this->blobs_[0].Reshape(1, 1, N_, K_);
+  this->blobs_[0].reset(new Blob<Dtype>(1, 1, N_, K_));
   // fill the weights
   shared_ptr<Filler<Dtype> > weight_filler(
       GetFiller<Dtype>(this->layer_param_.weight_filler()));
-  weight_filler->Fill(&this->blobs_[0]);
+  weight_filler->Fill(this->blobs_[0].get());
   // If necessary, intiialize and fill the bias term
   if (biasterm_) {
-    this->blobs_[1].Reshape(1, 1, 1, N_);
+    this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, N_));
     shared_ptr<Filler<Dtype> > bias_filler(
         GetFiller<Dtype>(this->layer_param_.bias_filler()));
-    bias_filler->Fill(&this->blobs_[1]);
+    bias_filler->Fill(this->blobs_[1].get());
     bias_multiplier_.reset(new SyncedMemory(M_ * sizeof(Dtype)));
     Dtype* bias_multiplier_data = (Dtype*)bias_multiplier_->mutable_cpu_data();
     for (int i = 0; i < M_; ++i) {
@@ -56,13 +56,13 @@ void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  const Dtype* weight = this->blobs_[0].cpu_data();
+  const Dtype* weight = this->blobs_[0]->cpu_data();
   caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
       bottom_data, weight, (Dtype)0., top_data);
   if (biasterm_) {
     caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
         reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()),
-        this->blobs_[1].cpu_data(), (Dtype)1., top_data);
+        this->blobs_[1]->cpu_data(), (Dtype)1., top_data);
   }
 }
 
@@ -74,17 +74,17 @@ Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   const Dtype* bottom_data = (*bottom)[0]->cpu_data();
   // Gradient with respect to weight
   caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
-      bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_cpu_diff());
+      bottom_data, top_diff, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());
   if (biasterm_) {
     // Gradient with respect to bias
     caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
         reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()), (Dtype)0.,
-        this->blobs_[1].mutable_cpu_diff());
+        this->blobs_[1]->mutable_cpu_diff());
   }
   if (propagate_down) {
     // Gradient with respect to bottom data
     caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
-        top_diff, this->blobs_[0].cpu_data(), (Dtype)0.,
+        top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,
         (*bottom)[0]->mutable_cpu_diff());
   }
   return Dtype(0);
@@ -95,13 +95,13 @@ void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
-  const Dtype* weight = this->blobs_[0].gpu_data();
+  const Dtype* weight = this->blobs_[0]->gpu_data();
   caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
       bottom_data, weight, (Dtype)0., top_data);
   if (biasterm_) {
     caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
         reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
-        this->blobs_[1].gpu_data(), (Dtype)1., top_data);
+        this->blobs_[1]->gpu_data(), (Dtype)1., top_data);
   }
 }
 
@@ -113,17 +113,17 @@ Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* bottom_data = (*bottom)[0]->gpu_data();
   // Gradient with respect to weight
   caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
-      bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_gpu_diff());
+      bottom_data, top_diff, (Dtype)0., this->blobs_[0]->mutable_gpu_diff());
   if (biasterm_) {
     // Gradient with respect to bias
     caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
         reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
-        (Dtype)0., this->blobs_[1].mutable_gpu_diff());
+        (Dtype)0., this->blobs_[1]->mutable_gpu_diff());
   }
   if (propagate_down) {
     // Gradient with respect to bottom data
     caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
-        top_diff, this->blobs_[0].gpu_data(), (Dtype)0.,
+        top_diff, this->blobs_[0]->gpu_data(), (Dtype)0.,
         (*bottom)[0]->mutable_gpu_diff());
   }
   return Dtype(0);
index 93969b1..75c9043 100644 (file)
@@ -28,20 +28,25 @@ Net<Dtype>::Net(const NetParameter& param,
   // set the input blobs
   for (int i = 0; i < param.bottom_size(); ++i) {
     const string& blob_name = param.bottom(i);
-    blobs_.push_back(Blob<Dtype>(*bottom[i]));
+    CHECK_GT(bottom[i]->count(), 0);
+    shared_ptr<Blob<Dtype> > blob_pointer(
+        new Blob<Dtype>(bottom[i]->num(), bottom[i]->channels(),
+            bottom[i]->height(), bottom[i]->width()));
+    blobs_.push_back(blob_pointer);
     blob_names_.push_back(blob_name);
     net_input_blob_indices_.push_back(i);
     blob_name_to_idx[blob_name] = i;
     available_blobs.insert(blob_name);
   }
   // For each layer, set up their input and output
-  layers_.resize(param.layers_size());
   bottom_vecs_.resize(param.layers_size());
   top_vecs_.resize(param.layers_size());
-  for (int i = 0; i < param.top_size(); ++i) {
+  for (int i = 0; i < param.layers_size(); ++i) {
     const LayerConnection& layer_connection = param.layers(i);
     const LayerParameter& layer_param = layer_connection.layer();
-    layers_[i].reset(GetLayer<Dtype>(layer_param));
+    layers_.push_back(shared_ptr<Layer<Dtype> >(GetLayer<Dtype>(layer_param)));
+    layer_names_.push_back(layer_param.name());
+    LOG(INFO) << "Creating Layer " << layer_param.name();
     // Figure out this layer's input and output
     for (int j = 0; j < layer_connection.bottom_size(); ++j) {
       const string& blob_name = layer_connection.bottom(j);
@@ -49,8 +54,9 @@ Net<Dtype>::Net(const NetParameter& param,
         LOG(FATAL) << "Unknown blob input " << blob_name <<
             " to layer" << j;
       }
+      LOG(INFO) << layer_param.name() << " <- " << blob_name;
       bottom_vecs_[i].push_back(
-          &blobs_[blob_name_to_idx[blob_name]]);
+          blobs_[blob_name_to_idx[blob_name]].get());
       available_blobs.erase(blob_name);
     }
     for (int j = 0; j < layer_connection.top_size(); ++j) {
@@ -58,18 +64,21 @@ Net<Dtype>::Net(const NetParameter& param,
       if (blob_name_to_idx.find(blob_name) != blob_name_to_idx.end()) {
         LOG(FATAL) << "Duplicate blobs produced by multiple sources.";
       }
-      blobs_.push_back(Blob<Dtype>());
+      LOG(INFO) << layer_param.name() << " -> " << blob_name;
+      shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
+      blobs_.push_back(blob_pointer);
       blob_names_.push_back(blob_name);
       blob_name_to_idx[blob_name] = blob_names_.size() - 1;
       available_blobs.insert(blob_name);
-      top_vecs_[i].push_back(&blobs_[blob_names_.size() - 1]);
+      top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());
     }
   }
+  LOG(INFO) << "Checking top blobs.";
   // In the end, check if all remaining available blobs are top blobs.
   for (int i = 0; i < param.top_size(); ++i) {
     const string& blob_name = param.top(i);
     if (blob_name_to_idx.find(blob_name) == blob_name_to_idx.end()) {
-      LOG(FATAL) << "Unknown blob input " << blob_name;
+      LOG(FATAL) << "Unknown blob output " << blob_name;
     }
     net_output_blob_indices_.push_back(blob_name_to_idx[blob_name]);
     available_blobs.erase(blob_name);
@@ -84,8 +93,10 @@ Net<Dtype>::Net(const NetParameter& param,
 
   LOG(INFO) << "Setting up the layers.";
   for (int i = 0; i < layers_.size(); ++i) {
+    LOG(INFO) << "Setting up " << layer_names_[i];
     layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
   }
+  LOG(INFO) << "Network initialization done.";
 }
 
 template <typename Dtype>
@@ -93,7 +104,8 @@ void Net<Dtype>::Forward(const vector<Blob<Dtype>*> & bottom,
     vector<Blob<Dtype>*>* top) {
   // Copy bottom to internal bottom
   for (int i = 0; i < bottom.size(); ++i) {
-    blobs_[net_input_blob_indices_[i]] = *bottom[i];
+    memcpy(blobs_[net_input_blob_indices_[i]]->mutable_cpu_data(),
+        bottom[i]->cpu_data(), sizeof(Dtype) * bottom[i]->count());
   }
   for (int i = 0; i < layers_.size(); ++i) {
     layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
@@ -130,11 +142,12 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
       continue;
     }
     LOG(INFO) << "Loading source layer " << source_layer_name;
-    vector<Blob<Dtype> >& target_blobs = layers_[target_layer_id]->params();
+    vector<shared_ptr<Blob<Dtype> > >& target_blobs =
+        layers_[target_layer_id]->params();
     CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
         << "Incompatible number of blobs for layer " << source_layer_name;
     for (int j = 0; j < target_blobs.size(); ++j) {
-      target_blobs[j].FromProto(source_layer.blobs(j));
+      target_blobs[j]->FromProto(source_layer.blobs(j));
     }
   }
 }
index debb59e..45ea708 100644 (file)
@@ -8,6 +8,7 @@
 #include <vector>
 
 #include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
 #include "caffe/common.hpp"
 #include "caffe/proto/caffe.pb.h"
 
@@ -38,6 +39,10 @@ class Net {
 
   // returns the network name.
   const string& name() { return name_; }
+  // returns the layer names
+  const vector<string>& layer_names() { return layer_names_; }
+  // returns the blob names
+  const vector<string>& blob_names() { return blob_names_; }
 
  protected:
   // Individual layers in the net
@@ -45,7 +50,7 @@ class Net {
   vector<string> layer_names_;
   // blobs stores the blobs that store intermediate results between the
   // layers.
-  vector<Blob<Dtype> > blobs_;
+  vector<shared_ptr<Blob<Dtype> > > blobs_;
   vector<string> blob_names_;
   // bottom_vecs stores the vectors containing the input for each layer, except
   // for the first layer whose bottom vec is provided by the network's input.
diff --git a/src/caffe/proto/lenet.prototxt b/src/caffe/proto/lenet.prototxt
deleted file mode 100644 (file)
index b4cb31b..0000000
+++ /dev/null
@@ -1,89 +0,0 @@
-name: "LeNet"
-bottom: "data"
-bottom: "label"
-layers {
-  layer {
-    name: "conv1"
-    type: "conv"
-    num_output: 20
-    kernelsize: 5
-    stride: 1
-  }
-  bottom: "data"
-  top: "conv1"
-}
-layers {
-  layer {
-    name: "pool1"
-    type: "pool"
-    kernelsize: 2
-    stride: 2
-    pool: MAX
-  }
-  bottom: "conv1"
-  top: "pool1"
-}
-layers {
-  layer {
-    name: "conv2"
-    type: "conv"
-    num_output: 50
-    kernelsize: 5
-    stride: 1
-  }
-  bottom: "pool1"
-  top: "conv2"
-}
-layers {
-  layer {
-    name: "pool2"
-    type: "pool"
-    kernelsize: 2
-    stride: 2
-    pool: MAX
-  }
-  bottom: "conv2"
-  top: "pool2"
-}
-layers {
-  layer {
-    name: "ip1"
-    type: "innerproduct"
-    num_output: 500
-  }
-  bottom: "pool2"
-  top: "ip1"
-}
-layers {
-  layer {
-    name: "relu1"
-    type: "relu"
-  }
-  bottom: "ip1"
-  top: "relu1"
-}
-layers {
-  layer {
-    name: "ip2"
-    type: "innerproduct"
-    num_output: 10
-  }
-  bottom: "relu1"
-  top: "ip2"
-}
-layers {
-  layer {
-    name: "prob"
-    type: "softmax"
-  }
-  bottom: "ip2"
-  top: "prob"
-}
-layers {
-  layer {
-    name: "loss"
-    type: "softmaxloss"
-  }
-  bottom: "prob"
-  bottom: "label"
-}
\ No newline at end of file
diff --git a/src/caffe/test/lenet.hpp b/src/caffe/test/lenet.hpp
new file mode 100644 (file)
index 0000000..29ec3c4
--- /dev/null
@@ -0,0 +1,100 @@
+#ifndef CAFFE_TEST_LENET_HPP_
+#define CAFFE_TEST_LENET_HPP_
+
+#include <string>
+
+namespace caffe {
+
+const char* kLENET = "name: \"LeNet\"\n\
+bottom: \"data\"\n\
+bottom: \"label\"\n\
+layers {\n\
+  layer {\n\
+    name: \"conv1\"\n\
+    type: \"conv\"\n\
+    num_output: 20\n\
+    kernelsize: 5\n\
+    stride: 1\n\
+  }\n\
+  bottom: \"data\"\n\
+  top: \"conv1\"\n\
+}\n\
+layers {\n\
+  layer {\n\
+    name: \"pool1\"\n\
+    type: \"pool\"\n\
+    kernelsize: 2\n\
+    stride: 2\n\
+    pool: MAX\n\
+  }\n\
+  bottom: \"conv1\"\n\
+  top: \"pool1\"\n\
+}\n\
+layers {\n\
+  layer {\n\
+    name: \"conv2\"\n\
+    type: \"conv\"\n\
+    num_output: 50\n\
+    kernelsize: 5\n\
+    stride: 1\n\
+  }\n\
+  bottom: \"pool1\"\n\
+  top: \"conv2\"\n\
+}\n\
+layers {\n\
+  layer {\n\
+    name: \"pool2\"\n\
+    type: \"pool\"\n\
+    kernelsize: 2\n\
+    stride: 2\n\
+    pool: MAX\n\
+  }\n\
+  bottom: \"conv2\"\n\
+  top: \"pool2\"\n\
+}\n\
+layers {\n\
+  layer {\n\
+    name: \"ip1\"\n\
+    type: \"innerproduct\"\n\
+    num_output: 500\n\
+  }\n\
+  bottom: \"pool2\"\n\
+  top: \"ip1\"\n\
+}\n\
+layers {\n\
+  layer {\n\
+    name: \"relu1\"\n\
+    type: \"relu\"\n\
+  }\n\
+  bottom: \"ip1\"\n\
+  top: \"relu1\"\n\
+}\n\
+layers {\n\
+  layer {\n\
+    name: \"ip2\"\n\
+    type: \"innerproduct\"\n\
+    num_output: 10\n\
+  }\n\
+  bottom: \"relu1\"\n\
+  top: \"ip2\"\n\
+}\n\
+layers {\n\
+  layer {\n\
+    name: \"prob\"\n\
+    type: \"softmax\"\n\
+  }\n\
+  bottom: \"ip2\"\n\
+  top: \"prob\"\n\
+}\n\
+layers {\n\
+  layer {\n\
+    name: \"loss\"\n\
+    type: \"multinomial_logistic_loss\"\n\
+  }\n\
+  bottom: \"prob\"\n\
+  bottom: \"label\"\n\
+}";
+
+}  // namespace caffe
+
+#endif
index 31bb919..ba76ed1 100644 (file)
@@ -57,22 +57,4 @@ TYPED_TEST(BlobSimpleTest, TestReshape) {
   EXPECT_EQ(this->blob_->count(), 120);
 }
 
-TYPED_TEST(BlobSimpleTest, TestCopyConstructor) {
-  Blob<TypeParam> source(2, 3, 4, 5);
-  FillerParameter filler_param;
-  UniformFiller<TypeParam> filler(filler_param);
-  filler.Fill(&source);
-  Blob<TypeParam> target(source);
-  const TypeParam* source_data = source.cpu_data();
-  const TypeParam* target_data = target.cpu_data();
-  EXPECT_EQ(target.num(), source.num());
-  EXPECT_EQ(target.channels(), source.channels());
-  EXPECT_EQ(target.height(), source.height());
-  EXPECT_EQ(target.width(), source.width());
-  EXPECT_EQ(target.count(), source.count());
-  for (int i = 0; i < source.count(); ++i) {
-    EXPECT_EQ(source_data[i], target_data[i]);
-  }
-}
-
 }
index dbaa7ba..0c34861 100644 (file)
@@ -62,7 +62,7 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
   // First, figure out what blobs we need to check against.
   vector<Blob<Dtype>*> blobs_to_check;
   for (int i = 0; i < layer.params().size(); ++i) {
-    blobs_to_check.push_back(&layer.params()[i]);
+    blobs_to_check.push_back(layer.params()[i].get());
   }
   if (check_bottom < 0) {
     for (int i = 0; i < bottom.size(); ++i) {
index 212b92f..3ccd34e 100644 (file)
@@ -15,7 +15,7 @@
 namespace caffe {
 
 extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-  
+
 template <typename Dtype>
 class InnerProductLayerTest : public ::testing::Test {
  protected:
diff --git a/src/caffe/test/test_net_proto.cpp b/src/caffe/test/test_net_proto.cpp
new file mode 100644 (file)
index 0000000..b328d3a
--- /dev/null
@@ -0,0 +1,47 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cstring>
+#include <cuda_runtime.h>
+#include <google/protobuf/text_format.h>
+#include <gtest/gtest.h>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/net.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/test/lenet.hpp"
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+class NetProtoTest : public ::testing::Test {};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(NetProtoTest, Dtypes);
+
+TYPED_TEST(NetProtoTest, TestSetup) {
+  NetParameter net_param;
+  string lenet_string(kLENET);
+  // Load the network
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      lenet_string, &net_param));
+  // check if things are right
+  EXPECT_EQ(net_param.layers_size(), 9);
+  EXPECT_EQ(net_param.bottom_size(), 2);
+  EXPECT_EQ(net_param.top_size(), 0);
+
+  // Now, initialize a network using the parameter
+  shared_ptr<Blob<TypeParam> > data(new Blob<TypeParam>(10, 1, 28, 28));
+  shared_ptr<Blob<TypeParam> > label(new Blob<TypeParam>(10, 1, 1, 1));
+  vector<Blob<TypeParam>*> bottom_vec;
+  bottom_vec.push_back(data.get());
+  bottom_vec.push_back(label.get());
+
+  Net<TypeParam> caffe_net(net_param, bottom_vec);
+  EXPECT_EQ(caffe_net.layer_names().size(), 9);
+  EXPECT_EQ(caffe_net.blob_names().size(), 10);
+}
+
+}  // namespace caffe
index 3429618..a5d0c9f 100644 (file)
@@ -56,44 +56,6 @@ TYPED_TEST(PoolingLayerTest, TestSetup) {
   EXPECT_EQ(this->blob_top_->width(), 2);
 }
 
-TYPED_TEST(PoolingLayerTest, TestGPUMax) {
-  LayerParameter layer_param;
-  layer_param.set_kernelsize(3);
-  layer_param.set_stride(2);
-  layer_param.set_pool(LayerParameter_PoolMethod_MAX);
-  Caffe::set_mode(Caffe::CPU);
-  PoolingLayer<TypeParam> layer(layer_param);
-  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
-  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
-  Blob<TypeParam> blob_reference(*this->blob_top_);
-  Caffe::set_mode(Caffe::GPU);
-  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
-  for (int i = 0; i < blob_reference.count(); ++i) {
-    EXPECT_EQ(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i])
-        << "debug: index " << i;
-  }
-}
-
-TYPED_TEST(PoolingLayerTest, TestGPUAve) {
-  LayerParameter layer_param;
-  layer_param.set_kernelsize(3);
-  layer_param.set_stride(2);
-  layer_param.set_pool(LayerParameter_PoolMethod_AVE);
-  Caffe::set_mode(Caffe::CPU);
-  PoolingLayer<TypeParam> layer(layer_param);
-  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
-  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
-  Blob<TypeParam> blob_reference(*this->blob_top_);
-  Caffe::set_mode(Caffe::GPU);
-  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
-  for (int i = 0; i < blob_reference.count(); ++i) {
-    EXPECT_GE(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i] - 1e-4)
-        << "debug: index " << i;
-    EXPECT_LE(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i] + 1e-4)
-        << "debug: index " << i;
-  }
-}
-
 /*
 TYPED_TEST(PoolingLayerTest, PrintGPUBackward) {
   LayerParameter layer_param;