use a Blob<int> instead of a SyncedMemory to store max_idx_
authorJeff Donahue <jeff.donahue@gmail.com>
Sat, 24 May 2014 23:27:38 +0000 (16:27 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Sun, 25 May 2014 01:43:10 +0000 (18:43 -0700)
include/caffe/vision_layers.hpp
src/caffe/blob.cpp
src/caffe/layers/pooling_layer.cpp
src/caffe/layers/pooling_layer.cu

index c0bbd37..a43f7f9 100644 (file)
@@ -350,7 +350,7 @@ class PoolingLayer : public Layer<Dtype> {
   int pooled_height_;
   int pooled_width_;
   Blob<Dtype> rand_idx_;
-  shared_ptr<SyncedMemory> max_idx_;
+  shared_ptr<Blob<int> > max_idx_;
 };
 
 /* SoftmaxLayer
index 444e9cf..597e848 100644 (file)
@@ -108,6 +108,14 @@ void Blob<Dtype>::ShareDiff(const Blob& other) {
   diff_ = other.diff();
 }
 
+template <>
+void Blob<int>::Update() {
+  // The "update" method is used for parameter blobs in a Net, which are stored
+  // as Blob<float>s or Blob<double>s -- hence we do not define it for
+  // Blob<int>s.
+  NOT_IMPLEMENTED;
+}
+
 template <typename Dtype>
 void Blob<Dtype>::Update() {
   // We will perform update based on where the data is located.
@@ -201,6 +209,7 @@ void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
 }
 
 INSTANTIATE_CLASS(Blob);
+template class Blob<int>;
 
 }  // namespace caffe
 
index 5c740de..3d216de 100644 (file)
@@ -40,7 +40,8 @@ void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   // If max pooling, we will initialize the vector index part.
   if (this->layer_param_.pooling_param().pool() ==
       PoolingParameter_PoolMethod_MAX) {
-    max_idx_.reset(new SyncedMemory((*top)[0]->count() * sizeof(int)));
+    max_idx_.reset(new Blob<int>(bottom[0]->num(), channels_,
+                                 pooled_height_, pooled_width_));
   }
   // If stochastic pooling, we will initialize the random index part.
   if (this->layer_param_.pooling_param().pool() ==
@@ -64,7 +65,7 @@ Dtype PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   switch (this->layer_param_.pooling_param().pool()) {
   case PoolingParameter_PoolMethod_MAX:
   // Initialize
-    mask = static_cast<int*>(max_idx_->mutable_cpu_data());
+    mask = max_idx_->mutable_cpu_data();
     for (int i = 0; i < top_count; ++i) {
       top_data[i] = -FLT_MAX;
       mask[i] = -1;
@@ -154,7 +155,7 @@ void PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   switch (this->layer_param_.pooling_param().pool()) {
   case PoolingParameter_PoolMethod_MAX:
     // The main loop
-    mask = static_cast<const int*>(max_idx_->cpu_data());
+    mask = max_idx_->cpu_data();
     for (int n = 0; n < top[0]->num(); ++n) {
       for (int c = 0; c < channels_; ++c) {
         for (int ph = 0; ph < pooled_height_; ++ph) {
index cb3bfab..49e5d77 100644 (file)
@@ -153,7 +153,7 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   int* mask;
   switch (this->layer_param_.pooling_param().pool()) {
   case PoolingParameter_PoolMethod_MAX:
-    mask = static_cast<int*>(max_idx_->mutable_gpu_data());
+    mask = max_idx_->mutable_gpu_data();
     // NOLINT_NEXT_LINE(whitespace/operators)
     MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, bottom_data, bottom[0]->num(), channels_,
@@ -303,7 +303,7 @@ void PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const int* mask;
   switch (this->layer_param_.pooling_param().pool()) {
   case PoolingParameter_PoolMethod_MAX:
-    mask = static_cast<const int*>(max_idx_->gpu_data());
+    mask = max_idx_->gpu_data();
     caffe_gpu_set(count, Dtype(0.), bottom_diff);
     // NOLINT_NEXT_LINE(whitespace/operators)
     MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(