template <typename Dtype>
void ArgMaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
- std::vector<int> shape(4, 1);
- shape[0] = bottom[0]->shape(0);
- // Produces max_ind
- shape[2] = top_k_;
+ std::vector<int> shape(bottom[0]->num_axes(), 1);
if (has_axis_) {
// Produces max_ind or max_val per axis
shape = bottom[0]->shape();
shape[axis_] = top_k_;
- } else if (out_max_val_) {
- // Produces max_ind and max_val
- shape[1] = 2;
+ } else {
+ shape[0] = bottom[0]->shape(0);
+ // Produces max_ind
+ shape[2] = top_k_;
+ if (out_max_val_) {
+ // Produces max_ind and max_val
+ shape[1] = 2;
+ }
}
top[0]->Reshape(shape);
}
if (out_max_val_) {
if (has_axis_) {
// Produces max_val per axis
- top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] =
- bottom_data_vector[j].first;
+ top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist]
+ = bottom_data_vector[j].first;
} else {
// Produces max_ind and max_val
- top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second;
- top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first;
+ top_data[2 * i * top_k_ + j] = bottom_data_vector[j].second;
+ top_data[2 * i * top_k_ + top_k_ + j] = bottom_data_vector[j].first;
}
} else {
// Produces max_ind per axis
- top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] =
- bottom_data_vector[j].second;
+ top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist]
+ = bottom_data_vector[j].second;
}
}
}
EXPECT_EQ(this->blob_top_->channels(), 2);
}
+TYPED_TEST(ArgMaxLayerTest, TestSetupAxis) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(0);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(this->blob_top_->shape(0), argmax_param->top_k());
+ EXPECT_EQ(this->blob_top_->shape(1), this->blob_bottom_->shape(0));
+ EXPECT_EQ(this->blob_top_->shape(2), this->blob_bottom_->shape(2));
+ EXPECT_EQ(this->blob_top_->shape(3), this->blob_bottom_->shape(3));
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestSetupAxisNegativeIndexing) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(-2);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(this->blob_top_->shape(0), this->blob_bottom_->shape(0));
+ EXPECT_EQ(this->blob_top_->shape(1), this->blob_bottom_->shape(1));
+ EXPECT_EQ(this->blob_top_->shape(2), argmax_param->top_k());
+ EXPECT_EQ(this->blob_top_->shape(3), this->blob_bottom_->shape(3));
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestSetupAxisMaxVal) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(2);
+ argmax_param->set_out_max_val(true);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(this->blob_top_->shape(0), this->blob_bottom_->shape(0));
+ EXPECT_EQ(this->blob_top_->shape(1), this->blob_bottom_->shape(1));
+ EXPECT_EQ(this->blob_top_->shape(2), argmax_param->top_k());
+ EXPECT_EQ(this->blob_top_->shape(3), this->blob_bottom_->shape(3));
+}
+
TYPED_TEST(ArgMaxLayerTest, TestCPU) {
LayerParameter layer_param;
ArgMaxLayer<TypeParam> layer(layer_param);
}
}
+TYPED_TEST(ArgMaxLayerTest, TestCPUAxis) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(0);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ int max_ind;
+ TypeParam max_val;
+ std::vector<int> shape = this->blob_bottom_->shape();
+ for (int i = 0; i < shape[1]; ++i) {
+ for (int j = 0; j < shape[2]; ++j) {
+ for (int k = 0; k < shape[3]; ++k) {
+ max_ind = this->blob_top_->data_at(0, i, j, k);
+ max_val = this->blob_bottom_->data_at(max_ind, i, j, k);
+ EXPECT_GE(max_ind, 0);
+ EXPECT_LE(max_ind, shape[0]);
+ for (int l = 0; l < shape[0]; ++l) {
+ EXPECT_LE(this->blob_bottom_->data_at(l, i, j, k), max_val);
+ }
+ }
+ }
+ }
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestCPUAxisTopK) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(2);
+ argmax_param->set_top_k(this->top_k_);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ int max_ind;
+ TypeParam max_val;
+ std::vector<int> shape = this->blob_bottom_->shape();
+ for (int i = 0; i < shape[0]; ++i) {
+ for (int j = 0; j < shape[1]; ++j) {
+ for (int k = 0; k < shape[3]; ++k) {
+ for (int m = 0; m < this->top_k_; ++m) {
+ max_ind = this->blob_top_->data_at(i, j, m, k);
+ max_val = this->blob_bottom_->data_at(i, j, max_ind, k);
+ EXPECT_GE(max_ind, 0);
+ EXPECT_LE(max_ind, shape[2]);
+ int count = 0;
+ for (int l = 0; l < shape[2]; ++l) {
+ if (this->blob_bottom_->data_at(i, j, l, k) > max_val) {
+ ++count;
+ }
+ }
+ EXPECT_EQ(m, count);
+ }
+ }
+ }
+ }
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestCPUAxisMaxValTopK) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(-1);
+ argmax_param->set_top_k(this->top_k_);
+ argmax_param->set_out_max_val(true);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ TypeParam max_val;
+ std::vector<int> shape = this->blob_bottom_->shape();
+ for (int i = 0; i < shape[0]; ++i) {
+ for (int j = 0; j < shape[1]; ++j) {
+ for (int k = 0; k < shape[2]; ++k) {
+ for (int m = 0; m < this->top_k_; ++m) {
+ max_val = this->blob_top_->data_at(i, j, k, m);
+ int count = 0;
+ for (int l = 0; l < shape[3]; ++l) {
+ if (this->blob_bottom_->data_at(i, j, k, l) > max_val) {
+ ++count;
+ }
+ }
+ EXPECT_EQ(m, count);
+ }
+ }
+ }
+ }
+}
} // namespace caffe