BlobMathTest: fixes for numerical precision issues
authorJeff Donahue <jeff.donahue@gmail.com>
Mon, 16 Feb 2015 01:05:10 +0000 (17:05 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Mon, 16 Feb 2015 01:05:10 +0000 (17:05 -0800)
src/caffe/test/test_blob.cpp

index 84d84e8..e067806 100644 (file)
@@ -59,9 +59,12 @@ class BlobMathTest : public MultiDeviceTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;
  protected:
   BlobMathTest()
-      : blob_(new Blob<Dtype>(2, 3, 4, 5)) {}
+      : blob_(new Blob<Dtype>(2, 3, 4, 5)),
+        epsilon_(1e-6) {}
+
   virtual ~BlobMathTest() { delete blob_; }
   Blob<Dtype>* const blob_;
+  Dtype epsilon_;
 };
 
 TYPED_TEST_CASE(BlobMathTest, TestDtypesAndDevices);
@@ -95,7 +98,8 @@ TYPED_TEST(BlobMathTest, TestSumOfSquares) {
   default:
     LOG(FATAL) << "Unknown device: " << TypeParam::device;
   }
-  EXPECT_FLOAT_EQ(expected_sumsq, this->blob_->sumsq_data());
+  EXPECT_NEAR(expected_sumsq, this->blob_->sumsq_data(),
+              this->epsilon_ * expected_sumsq);
   EXPECT_EQ(0, this->blob_->sumsq_diff());
 
   // Check sumsq_diff too.
@@ -112,9 +116,12 @@ TYPED_TEST(BlobMathTest, TestSumOfSquares) {
   default:
     LOG(FATAL) << "Unknown device: " << TypeParam::device;
   }
-  EXPECT_FLOAT_EQ(expected_sumsq, this->blob_->sumsq_data());
-  EXPECT_FLOAT_EQ(expected_sumsq * kDiffScaleFactor * kDiffScaleFactor,
-                  this->blob_->sumsq_diff());
+  EXPECT_NEAR(expected_sumsq, this->blob_->sumsq_data(),
+              this->epsilon_ * expected_sumsq);
+  const Dtype expected_sumsq_diff =
+      expected_sumsq * kDiffScaleFactor * kDiffScaleFactor;
+  EXPECT_NEAR(expected_sumsq_diff, this->blob_->sumsq_diff(),
+              this->epsilon_ * expected_sumsq_diff);
 }
 
 TYPED_TEST(BlobMathTest, TestAsum) {
@@ -146,7 +153,8 @@ TYPED_TEST(BlobMathTest, TestAsum) {
   default:
     LOG(FATAL) << "Unknown device: " << TypeParam::device;
   }
-  EXPECT_FLOAT_EQ(expected_asum, this->blob_->asum_data());
+  EXPECT_NEAR(expected_asum, this->blob_->asum_data(),
+              this->epsilon_ * expected_asum);
   EXPECT_EQ(0, this->blob_->asum_diff());
 
   // Check asum_diff too.
@@ -163,8 +171,11 @@ TYPED_TEST(BlobMathTest, TestAsum) {
   default:
     LOG(FATAL) << "Unknown device: " << TypeParam::device;
   }
-  EXPECT_FLOAT_EQ(expected_asum, this->blob_->asum_data());
-  EXPECT_FLOAT_EQ(expected_asum * kDiffScaleFactor, this->blob_->asum_diff());
+  EXPECT_NEAR(expected_asum, this->blob_->asum_data(),
+              this->epsilon_ * expected_asum);
+  const Dtype expected_diff_asum = expected_asum * kDiffScaleFactor;
+  EXPECT_NEAR(expected_diff_asum, this->blob_->asum_diff(),
+              this->epsilon_ * expected_diff_asum);
 }
 
 TYPED_TEST(BlobMathTest, TestScaleData) {
@@ -193,8 +204,8 @@ TYPED_TEST(BlobMathTest, TestScaleData) {
   }
   const Dtype kDataScaleFactor = 3;
   this->blob_->scale_data(kDataScaleFactor);
-  EXPECT_FLOAT_EQ(asum_before_scale * kDataScaleFactor,
-                  this->blob_->asum_data());
+  EXPECT_NEAR(asum_before_scale * kDataScaleFactor, this->blob_->asum_data(),
+              this->epsilon_ * asum_before_scale * kDataScaleFactor);
   EXPECT_EQ(0, this->blob_->asum_diff());
 
   // Check scale_diff too.
@@ -202,11 +213,13 @@ TYPED_TEST(BlobMathTest, TestScaleData) {
   const Dtype* data = this->blob_->cpu_data();
   caffe_cpu_scale(this->blob_->count(), kDataToDiffScaleFactor, data,
                   this->blob_->mutable_cpu_diff());
-  EXPECT_FLOAT_EQ(asum_before_scale * kDataScaleFactor,
-                  this->blob_->asum_data());
-  const Dtype diff_asum_before_scale = this->blob_->asum_diff();
-  EXPECT_FLOAT_EQ(asum_before_scale * kDataScaleFactor * kDataToDiffScaleFactor,
-                  diff_asum_before_scale);
+  const Dtype expected_asum_before_scale = asum_before_scale * kDataScaleFactor;
+  EXPECT_NEAR(expected_asum_before_scale, this->blob_->asum_data(),
+              this->epsilon_ * expected_asum_before_scale);
+  const Dtype expected_diff_asum_before_scale =
+      asum_before_scale * kDataScaleFactor * kDataToDiffScaleFactor;
+  EXPECT_NEAR(expected_diff_asum_before_scale, this->blob_->asum_diff(),
+              this->epsilon_ * expected_diff_asum_before_scale);
   switch (TypeParam::device) {
   case Caffe::CPU:
     this->blob_->mutable_cpu_diff();
@@ -219,10 +232,12 @@ TYPED_TEST(BlobMathTest, TestScaleData) {
   }
   const Dtype kDiffScaleFactor = 3;
   this->blob_->scale_diff(kDiffScaleFactor);
-  EXPECT_FLOAT_EQ(asum_before_scale * kDataScaleFactor,
-                  this->blob_->asum_data());
-  EXPECT_FLOAT_EQ(diff_asum_before_scale * kDiffScaleFactor,
-                  this->blob_->asum_diff());
+  EXPECT_NEAR(asum_before_scale * kDataScaleFactor, this->blob_->asum_data(),
+              this->epsilon_ * asum_before_scale * kDataScaleFactor);
+  const Dtype expected_diff_asum =
+      expected_diff_asum_before_scale * kDiffScaleFactor;
+  EXPECT_NEAR(expected_diff_asum, this->blob_->asum_diff(),
+              this->epsilon_ * expected_diff_asum);
 }
 
 }  // namespace caffe