bilinear filter test refactor
authorNoiredd <snowball91b@gmail.com>
Fri, 9 Mar 2018 12:14:32 +0000 (13:14 +0100)
committerWook Song <wook16.song@samsung.com>
Thu, 23 Jan 2020 13:50:37 +0000 (22:50 +0900)
src/caffe/test/test_filler.cpp

index 3ecec37..1e6b5c2 100644 (file)
@@ -500,21 +500,25 @@ TYPED_TEST(MSRAFillerTest, TestFill5D) {
 template <typename Dtype>
 class BilinearFillerTest : public ::testing::Test {
  protected:
-  BilinearFillerTest() : filler_param_() {}
-  virtual void test_params(const int n) {
-    this->blob_ = new Blob<Dtype>(1000, 2, n, n);
-    this->filler_.reset(new BilinearFiller<Dtype>(this->filler_param_));
-    this->filler_->Fill(blob_);
-    EXPECT_TRUE(this->blob_);
-    const int outer_num = this->blob_->count(0, 2);
-    const int inner_num = this->blob_->count(2, 4);
-    const Dtype* data = this->blob_->cpu_data();
-    int f = ceil(this->blob_->width() / 2.);
-    Dtype c = (this->blob_->width() - 1) / (2. * f);
+  BilinearFillerTest()
+    : blob_(new Blob<Dtype>()),
+      filler_param_() {
+  }
+  virtual void test_params(const vector<int>& shape) {
+    EXPECT_TRUE(blob_);
+    blob_->Reshape(shape);
+    filler_.reset(new BilinearFiller<Dtype>(filler_param_));
+    filler_->Fill(blob_);
+    CHECK_EQ(blob_->num_axes(), 4);
+    const int outer_num = blob_->count(0, 2);
+    const int inner_num = blob_->count(2, 4);
+    const Dtype* data = blob_->cpu_data();
+    int f = ceil(blob_->shape(3) / 2.);
+    Dtype c = (blob_->shape(3) - 1) / (2. * f);
     for (int i = 0; i < outer_num; ++i) {
       for (int j = 0; j < inner_num; ++j) {
-        Dtype x = j % this->blob_->width();
-        Dtype y = (j / this->blob_->width()) % this->blob_->height();
+        Dtype x = j % blob_->shape(3);
+        Dtype y = (j / blob_->shape(3)) % blob_->shape(2);
         Dtype expected_value = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
         const Dtype actual_value = data[i * inner_num + j];
         EXPECT_NEAR(expected_value, actual_value, 0.01);
@@ -531,11 +535,21 @@ TYPED_TEST_CASE(BilinearFillerTest, TestDtypes);
 
 TYPED_TEST(BilinearFillerTest, TestFillOdd) {
   const int n = 7;
-  this->test_params(n);
+  vector<int> blob_shape;
+  blob_shape.push_back(1000);
+  blob_shape.push_back(2);
+  blob_shape.push_back(n);
+  blob_shape.push_back(n);
+  this->test_params(blob_shape);
 }
 TYPED_TEST(BilinearFillerTest, TestFillEven) {
   const int n = 6;
-  this->test_params(n);
+  vector<int> blob_shape;
+  blob_shape.push_back(1000);
+  blob_shape.push_back(2);
+  blob_shape.push_back(n);
+  blob_shape.push_back(n);
+  this->test_params(blob_shape);
 }
 
 }  // namespace caffe