template <typename Dtype>
class BilinearFillerTest : public ::testing::Test {
protected:
- BilinearFillerTest() : filler_param_() {}
- virtual void test_params(const int n) {
- this->blob_ = new Blob<Dtype>(1000, 2, n, n);
- this->filler_.reset(new BilinearFiller<Dtype>(this->filler_param_));
- this->filler_->Fill(blob_);
- EXPECT_TRUE(this->blob_);
- const int outer_num = this->blob_->count(0, 2);
- const int inner_num = this->blob_->count(2, 4);
- const Dtype* data = this->blob_->cpu_data();
- int f = ceil(this->blob_->width() / 2.);
- Dtype c = (this->blob_->width() - 1) / (2. * f);
+ BilinearFillerTest()
+ : blob_(new Blob<Dtype>()),
+ filler_param_() {
+ }
+ virtual void test_params(const vector<int>& shape) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
+ filler_.reset(new BilinearFiller<Dtype>(filler_param_));
+ filler_->Fill(blob_);
+ CHECK_EQ(blob_->num_axes(), 4);
+ const int outer_num = blob_->count(0, 2);
+ const int inner_num = blob_->count(2, 4);
+ const Dtype* data = blob_->cpu_data();
+ int f = ceil(blob_->shape(3) / 2.);
+ Dtype c = (blob_->shape(3) - 1) / (2. * f);
for (int i = 0; i < outer_num; ++i) {
for (int j = 0; j < inner_num; ++j) {
- Dtype x = j % this->blob_->width();
- Dtype y = (j / this->blob_->width()) % this->blob_->height();
+ Dtype x = j % blob_->shape(3);
+ Dtype y = (j / blob_->shape(3)) % blob_->shape(2);
Dtype expected_value = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
const Dtype actual_value = data[i * inner_num + j];
EXPECT_NEAR(expected_value, actual_value, 0.01);
TYPED_TEST(BilinearFillerTest, TestFillOdd) {
const int n = 7;
- this->test_params(n);
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(n);
+ blob_shape.push_back(n);
+ this->test_params(blob_shape);
}
TYPED_TEST(BilinearFillerTest, TestFillEven) {
const int n = 6;
- this->test_params(n);
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(n);
+ blob_shape.push_back(n);
+ this->test_params(blob_shape);
}
} // namespace caffe