Generalise ArgMaxLayerTest bottom blob shape
authorTim Meinhardt <meinhardt.tim@gmail.com>
Tue, 15 Sep 2015 14:57:37 +0000 (16:57 +0200)
committerTim Meinhardt <meinhardt.tim@gmail.com>
Fri, 25 Sep 2015 10:06:00 +0000 (12:06 +0200)
src/caffe/test/test_argmax_layer.cpp

index 895c3d3..d3018f9 100644 (file)
@@ -16,7 +16,7 @@ template <typename Dtype>
 class ArgMaxLayerTest : public CPUDeviceTest<Dtype> {
  protected:
   ArgMaxLayerTest()
-      : blob_bottom_(new Blob<Dtype>(10, 20, 1, 1)),
+      : blob_bottom_(new Blob<Dtype>(10, 10, 20, 20)),
         blob_top_(new Blob<Dtype>()),
         top_k_(5) {
     Caffe::set_random_seed(1701);
@@ -112,6 +112,7 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) {
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
   // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
   int max_ind;
   TypeParam max_val;
   int num = this->blob_bottom_->num();
@@ -121,10 +122,10 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) {
     EXPECT_LE(this->blob_top_->data_at(i, 0, 0, 0), dim);
     for (int j = 0; j < this->top_k_; ++j) {
       max_ind = this->blob_top_->data_at(i, 0, j, 0);
-      max_val = this->blob_bottom_->data_at(i, max_ind, 0, 0);
+      max_val = bottom_data[i * dim + max_ind];
       int count = 0;
       for (int k = 0; k < dim; ++k) {
-        if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) {
+        if (bottom_data[i * dim + k] > max_val) {
           ++count;
         }
       }
@@ -142,6 +143,7 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) {
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
   // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
   int max_ind;
   TypeParam max_val;
   int num = this->blob_bottom_->num();
@@ -152,10 +154,10 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) {
     for (int j = 0; j < this->top_k_; ++j) {
       max_ind = this->blob_top_->data_at(i, 0, j, 0);
       max_val = this->blob_top_->data_at(i, 1, j, 0);
-      EXPECT_EQ(this->blob_bottom_->data_at(i, max_ind, 0, 0), max_val);
+      EXPECT_EQ(bottom_data[i * dim + max_ind], max_val);
       int count = 0;
       for (int k = 0; k < dim; ++k) {
-        if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) {
+        if (bottom_data[i * dim + k] > max_val) {
           ++count;
         }
       }