From c77d5e5156f94720c1decd13f7f87fe78df9d4eb Mon Sep 17 00:00:00 2001
From: Tim Meinhardt
Date: Tue, 15 Sep 2015 16:56:16 +0200
Subject: [PATCH] Implement ArgMaxLayer forward_cpu and reshape for axis param

---
 src/caffe/layers/argmax_layer.cpp | 53 ++++++++++++++++++++++++++++-----------
 1 file changed, 38 insertions(+), 15 deletions(-)

diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp
index dad3d08..18ff5f5 100644
--- a/src/caffe/layers/argmax_layer.cpp
+++ b/src/caffe/layers/argmax_layer.cpp
@@ -33,13 +33,19 @@ void ArgMaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
 template <typename Dtype>
 void ArgMaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-  if (out_max_val_) {
+  std::vector<int> shape(4, 1);
+  shape[0] = bottom[0]->shape(0);
+  // Produces max_ind
+  shape[2] = top_k_;
+  if (has_axis_) {
+    // Produces max_ind or max_val per axis
+    shape = bottom[0]->shape();
+    shape[axis_] = top_k_;
+  } else if (out_max_val_) {
     // Produces max_ind and max_val
-    top[0]->Reshape(bottom[0]->num(), 2, top_k_, 1);
-  } else {
-    // Produces only max_ind
-    top[0]->Reshape(bottom[0]->num(), 1, top_k_, 1);
+    shape[1] = 2;
   }
+  top[0]->Reshape(shape);
 }
 
 template <typename Dtype>
@@ -47,23 +53,40 @@ void ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = top[0]->mutable_cpu_data();
-  int num = bottom[0]->num();
-  int dim = bottom[0]->count() / bottom[0]->num();
+  int dim, axis_dist;
+  if (has_axis_) {
+    dim = bottom[0]->shape(axis_);
+    // Distance between values of axis in blob
+    axis_dist = bottom[0]->count(axis_) / dim;
+  } else {
+    dim = bottom[0]->count(1);
+    axis_dist = 1;
+  }
+  int num = bottom[0]->count() / dim;
+  std::vector<std::pair<Dtype, int> > bottom_data_vector(dim);
   for (int i = 0; i < num; ++i) {
-    std::vector<std::pair<Dtype, int> > bottom_data_vector;
     for (int j = 0; j < dim; ++j) {
-      bottom_data_vector.push_back(
-          std::make_pair(bottom_data[i * dim + j], j));
+      bottom_data_vector[j] = std::make_pair(
+        bottom_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j);
     }
     std::partial_sort(
         bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
         bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
     for (int j = 0; j < top_k_; ++j) {
-      top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second;
-    }
-    if (out_max_val_) {
-      for (int j = 0; j < top_k_; ++j) {
-        top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first;
+      if (out_max_val_) {
+        if (has_axis_) {
+          // Produces max_val per axis
+          top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] =
+            bottom_data_vector[j].first;
+        } else {
+          // Produces max_ind and max_val
+          top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second;
+          top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first;
+        }
+      } else {
+        // Produces max_ind per axis
+        top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] =
+          bottom_data_vector[j].second;
       }
     }
   }
-- 
2.7.4