[nnc] Fix float equality check in the interpreter test (#915)
authorDmitry Mozolev/AI Tools Lab /SRR/Engineer/삼성전자 <d.mozolev@samsung.com>
Tue, 7 Aug 2018 07:34:09 +0000 (10:34 +0300)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Tue, 7 Aug 2018 07:34:09 +0000 (10:34 +0300)
The problem was that the custom float comparing function wasn't
checking for some corner cases, like floats having a different
sign or being +0 and -0 (although this is very unlikely of course).

Signed-off-by: Dmitry Mozolev <d.mozolev@samsung.com>
contrib/nnc/libs/backend/interpreter/test/src/op_info_util.cpp

index 9c826b7..d504483 100644 (file)
@@ -1,9 +1,11 @@
+#include <limits>
+#include <cstdint>
+
 #include "gtest/gtest.h"
 
 #include "nnc/core/linalg/Tensor.h"
 #include "nnc/core/linalg/ShapeRange.h"
 #include "nncc/core/ADT/tensor/Shape.h"
-
 #include "op_info_util.h"
 
 std::shared_ptr<TensorVariant> getTensor(const opinfo::Tensor* t)
@@ -80,13 +82,25 @@ __attribute__ ((unused)) void printTensor(const TensorVariant& lhs)
 
 /** @brief Custom float comparator.
  * It is supposed to be equivalent to GTest's ASSERT_FLOAT_EQ when allowedUlpsDiff is 4.
+ * Reminder: if the integer representations of two same-sign floats are subtracted then
+ * the absolute value of the result is equal to one plus the number of representable floats
+ * between them. This difference tells us how many ULPs the numbers differ by.
+ * @usage This function only works if float implementation conforms to IEEE-754.
  */
 static inline ::testing::AssertionResult areFloatsEqual(float f1, float f2, int allowedUlpsDiff)
 {
-  auto intRepr1 = *reinterpret_cast<int*>(&f1);
-  auto intRepr2 = *reinterpret_cast<int*>(&f2);
+  auto intRepr1 = *reinterpret_cast<int32_t*>(&f1);
+  auto intRepr2 = *reinterpret_cast<int32_t*>(&f2);
+
+  if ((intRepr1 < 0) != (intRepr2 < 0))
+  {
+    if (f1 == f2) // Checking for +0 and -0
+      return ::testing::AssertionSuccess();
+    else
+      return ::testing::AssertionFailure() << "Different signs";
+  }
 
-  int ulpsDiff = std::abs(intRepr1 - intRepr2);
+  auto ulpsDiff = std::abs(intRepr1 - intRepr2);
 
   if (ulpsDiff <= allowedUlpsDiff)
     return ::testing::AssertionSuccess();
@@ -99,6 +113,8 @@ void assertTensorEq(const TensorVariant &lhs, const TensorVariant &rhs)
   using nncc::contrib::core::data::ShapeRange;
   using nncc::contrib::core::data::Tensor;
 
+  const int GTEST_FLOAT_EQ_ULP = 4;
+
   Tensor<float> lhsAccessor(lhs);
   Tensor<float> rhsAccessor(rhs);
 
@@ -106,7 +122,7 @@ void assertTensorEq(const TensorVariant &lhs, const TensorVariant &rhs)
 
   for(auto& idx : ShapeRange(lhsAccessor.getShape()))
   {
-    ASSERT_TRUE(areFloatsEqual(lhsAccessor.at(idx), rhsAccessor.at(idx), 4));
+    ASSERT_TRUE(areFloatsEqual(lhsAccessor.at(idx), rhsAccessor.at(idx), GTEST_FLOAT_EQ_ULP));
   }
 }