CHECK_EQ(blob->width(), blob->height()) << "Filter must be square";
Dtype* data = blob->mutable_cpu_data();
int f = ceil(blob->width() / 2.);
- float c = (2 * f - 1 - f % 2) / (2. * f);
+ Dtype c = (blob->width() - 1) / (2. * f);
for (int i = 0; i < blob->count(); ++i) {
- float x = i % blob->width();
- float y = (i / blob->width()) % blob->height();
+ Dtype x = i % blob->width();
+ Dtype y = (i / blob->width()) % blob->height();
data[i] = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
}
CHECK_EQ(this->filler_param_.sparse(), -1)
const int count = this->blob_->count();
const TypeParam* data = this->blob_->cpu_data();
for (int i = 0; i < count; ++i) {
- EXPECT_GE(data[i], this->filler_param_.value());
+ EXPECT_EQ(data[i], this->filler_param_.value());
}
}
this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
}
+template <typename Dtype>
+class BilinearFillerTest : public ::testing::Test {
+ protected:
+ BilinearFillerTest() : filler_param_() {}
+ virtual void test_params(const int n) {
+ this->blob_ = new Blob<Dtype>(1000, 2, n, n);
+ this->filler_.reset(new BilinearFiller<Dtype>(this->filler_param_));
+ this->filler_->Fill(blob_);
+ EXPECT_TRUE(this->blob_);
+ const int outer_num = this->blob_->count(0, 2);
+ const int inner_num = this->blob_->count(2, 4);
+ const Dtype* data = this->blob_->cpu_data();
+ int f = ceil(this->blob_->width() / 2.);
+ Dtype c = (this->blob_->width() - 1) / (2. * f);
+ for (int i = 0; i < outer_num; ++i) {
+ for (int j = 0; j < inner_num; ++j) {
+ Dtype x = j % this->blob_->width();
+ Dtype y = (j / this->blob_->width()) % this->blob_->height();
+ Dtype expected_value = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
+ const Dtype actual_value = data[i * inner_num + j];
+ EXPECT_NEAR(expected_value, actual_value, 0.01);
+ }
+ }
+ }
+ virtual ~BilinearFillerTest() { delete blob_; }
+ Blob<Dtype>* blob_;
+ FillerParameter filler_param_;
+ shared_ptr<BilinearFiller<Dtype> > filler_;
+};
+
+TYPED_TEST_CASE(BilinearFillerTest, TestDtypes);
+
+TYPED_TEST(BilinearFillerTest, TestFillOdd) {
+ const int n = 7;
+ this->test_params(n);
+}
+TYPED_TEST(BilinearFillerTest, TestFillEven) {
+ const int n = 6;
+ this->test_params(n);
+}
+
} // namespace caffe