Fixed bilinear filler, added tests
authorNoiredd <snowball91b@gmail.com>
Mon, 2 Oct 2017 12:39:31 +0000 (14:39 +0200)
committerNoiredd <snowball91b@gmail.com>
Mon, 2 Oct 2017 13:36:07 +0000 (15:36 +0200)
include/caffe/filler.hpp
src/caffe/test/test_filler.cpp

index dad9ad4..bb92ded 100644 (file)
@@ -250,10 +250,10 @@ class BilinearFiller : public Filler<Dtype> {
     CHECK_EQ(blob->width(), blob->height()) << "Filter must be square";
     Dtype* data = blob->mutable_cpu_data();
     int f = ceil(blob->width() / 2.);
-    float c = (2 * f - 1 - f % 2) / (2. * f);
+    Dtype c = (blob->width() - 1) / (2. * f);
     for (int i = 0; i < blob->count(); ++i) {
-      float x = i % blob->width();
-      float y = (i / blob->width()) % blob->height();
+      Dtype x = i % blob->width();
+      Dtype y = (i / blob->width()) % blob->height();
       data[i] = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
     }
     CHECK_EQ(this->filler_param_.sparse(), -1)
index 26e9b21..f84d707 100644 (file)
@@ -29,7 +29,7 @@ TYPED_TEST(ConstantFillerTest, TestFill) {
   const int count = this->blob_->count();
   const TypeParam* data = this->blob_->cpu_data();
   for (int i = 0; i < count; ++i) {
-    EXPECT_GE(data[i], this->filler_param_.value());
+    EXPECT_EQ(data[i], this->filler_param_.value());
   }
 }
 
@@ -238,4 +238,45 @@ TYPED_TEST(MSRAFillerTest, TestFillAverage) {
   this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
 }
 
+template <typename Dtype>
+class BilinearFillerTest : public ::testing::Test {
+ protected:
+  BilinearFillerTest() : filler_param_() {}
+  virtual void test_params(const int n) {
+    this->blob_ = new Blob<Dtype>(1000, 2, n, n);
+    this->filler_.reset(new BilinearFiller<Dtype>(this->filler_param_));
+    this->filler_->Fill(blob_);
+    EXPECT_TRUE(this->blob_);
+    const int outer_num = this->blob_->count(0, 2);
+    const int inner_num = this->blob_->count(2, 4);
+    const Dtype* data = this->blob_->cpu_data();
+    int f = ceil(this->blob_->width() / 2.);
+    Dtype c = (this->blob_->width() - 1) / (2. * f);
+    for (int i = 0; i < outer_num; ++i) {
+      for (int j = 0; j < inner_num; ++j) {
+        Dtype x = j % this->blob_->width();
+        Dtype y = (j / this->blob_->width()) % this->blob_->height();
+        Dtype expected_value = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
+        const Dtype actual_value = data[i * inner_num + j];
+        EXPECT_NEAR(expected_value, actual_value, 0.01);
+      }
+    }
+  }
+  virtual ~BilinearFillerTest() { delete blob_; }
+  Blob<Dtype>* blob_;
+  FillerParameter filler_param_;
+  shared_ptr<BilinearFiller<Dtype> > filler_;
+};
+
+TYPED_TEST_CASE(BilinearFillerTest, TestDtypes);
+
+TYPED_TEST(BilinearFillerTest, TestFillOdd) {
+  const int n = 7;
+  this->test_params(n);
+}
+TYPED_TEST(BilinearFillerTest, TestFillEven) {
+  const int n = 6;
+  this->test_params(n);
+}
+
 }  // namespace caffe