Implement ArgMaxLayer forward_cpu and reshape for axis param
authorTim Meinhardt <meinhardt.tim@gmail.com>
Tue, 15 Sep 2015 14:56:16 +0000 (16:56 +0200)
committerTim Meinhardt <meinhardt.tim@gmail.com>
Fri, 25 Sep 2015 10:05:54 +0000 (12:05 +0200)
src/caffe/layers/argmax_layer.cpp

index dad3d08..18ff5f5 100644 (file)
@@ -33,13 +33,19 @@ void ArgMaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
 template <typename Dtype>
 void ArgMaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-  if (out_max_val_) {
+  std::vector<int> shape(4, 1);
+  shape[0] = bottom[0]->shape(0);
+  // Produces max_ind
+  shape[2] = top_k_;
+  if (has_axis_) {
+    // Produces max_ind or max_val per axis
+    shape = bottom[0]->shape();
+    shape[axis_] = top_k_;
+  } else if (out_max_val_) {
     // Produces max_ind and max_val
-    top[0]->Reshape(bottom[0]->num(), 2, top_k_, 1);
-  } else {
-    // Produces only max_ind
-    top[0]->Reshape(bottom[0]->num(), 1, top_k_, 1);
+    shape[1] = 2;
   }
+  top[0]->Reshape(shape);
 }
 
 template <typename Dtype>
@@ -47,23 +53,40 @@ void ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = top[0]->mutable_cpu_data();
-  int num = bottom[0]->num();
-  int dim = bottom[0]->count() / bottom[0]->num();
+  int dim, axis_dist;
+  if (has_axis_) {
+    dim = bottom[0]->shape(axis_);
+    // Distance between values of axis in blob
+    axis_dist = bottom[0]->count(axis_) / dim;
+  } else {
+    dim = bottom[0]->count(1);
+    axis_dist = 1;
+  }
+  int num = bottom[0]->count() / dim;
+  std::vector<std::pair<Dtype, int> > bottom_data_vector(dim);
   for (int i = 0; i < num; ++i) {
-    std::vector<std::pair<Dtype, int> > bottom_data_vector;
     for (int j = 0; j < dim; ++j) {
-      bottom_data_vector.push_back(
-          std::make_pair(bottom_data[i * dim + j], j));
+      bottom_data_vector[j] = std::make_pair(
+        bottom_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j);
     }
     std::partial_sort(
         bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
         bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
     for (int j = 0; j < top_k_; ++j) {
-      top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second;
-    }
-    if (out_max_val_) {
-      for (int j = 0; j < top_k_; ++j) {
-        top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first;
+      if (out_max_val_) {
+        if (has_axis_) {
+          // Produces max_val per axis
+          top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] =
+            bottom_data_vector[j].first;
+        } else {
+          // Produces max_ind and max_val
+          top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second;
+          top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first;
+        }
+      } else {
+        // Produces max_ind per axis
+        top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] =
+          bottom_data_vector[j].second;
       }
     }
   }