Reject float literals with two leading signs
authorDavid Neto <dneto@google.com>
Wed, 10 Feb 2016 16:46:05 +0000 (11:46 -0500)
committerDavid Neto <dneto@google.com>
Thu, 11 Feb 2016 19:11:16 +0000 (14:11 -0500)
E.g. --1 should be rejected.

include/util/hex_float.h
test/HexFloat.cpp
test/TextToBinary.Constant.cpp

index 2241b99..d174739 100644 (file)
@@ -364,11 +364,11 @@ class HexFloat {
       // the significand is not zero.
       significand_is_zero = false;
       significand |= first_exponent_bit;
-      significand  = static_cast<uint_type>(significand >> 1);
+      significand = static_cast<uint_type>(significand >> 1);
     }
 
     while (exponent < min_exponent) {
-      significand  = static_cast<uint_type>(significand >> 1);
+      significand = static_cast<uint_type>(significand >> 1);
       ++exponent;
     }
 
@@ -585,9 +585,9 @@ class HexFloat {
     if (is_nan) {
       typename other_T::uint_type shifted_significand;
       shifted_significand = static_cast<typename other_T::uint_type>(
-          negatable_left_shift<static_cast<int_type>(other_T::num_fraction_bits) -
-                               static_cast<int_type>(
-                                   num_fraction_bits)>::val(significand));
+          negatable_left_shift<
+              static_cast<int_type>(other_T::num_fraction_bits) -
+              static_cast<int_type>(num_fraction_bits)>::val(significand));
 
       // We are some sort of Nan. We try to keep the bit-pattern of the Nan
       // as close as possible. If we had to shift off bits so we are 0, then we
@@ -680,7 +680,7 @@ std::ostream& operator<<(std::ostream& os, const HexFloat<T, Traits>& value) {
     }
     // Since this is denormalized, we have to consume the leading 1 since it
     // will end up being implicit.
-    fraction  = static_cast<uint_type>(fraction << 1);  // eat the leading 1
+    fraction = static_cast<uint_type>(fraction << 1);  // eat the leading 1
     fraction &= HF::fraction_represent_mask;
   }
 
@@ -689,7 +689,7 @@ std::ostream& operator<<(std::ostream& os, const HexFloat<T, Traits>& value) {
   // fractional part.
   while (fraction_nibbles > 0 && (fraction & 0xF) == 0) {
     // Shift off any trailing values;
-    fraction = static_cast<uint_type>(fraction >>  4);
+    fraction = static_cast<uint_type>(fraction >> 4);
     --fraction_nibbles;
   }
 
@@ -711,13 +711,36 @@ std::ostream& operator<<(std::ostream& os, const HexFloat<T, Traits>& value) {
   return os;
 }
 
+// Returns true if negate_value is true and the next character on the
+// input stream is a plus or minus sign.  In that case we also set the fail bit
+// on the stream and set the value to the zero value for its type.
+template <typename T, typename Traits>
+inline bool RejectParseDueToLeadingSign(std::istream& is, bool negate_value,
+                                        HexFloat<T, Traits>& value) {
+  if (negate_value) {
+    auto next_char = is.peek();
+    if (next_char == '-' || next_char == '+') {
+      // Fail the parse.  Emulate standard behaviour by setting the value to
+      // the zero value, and set the fail bit on the stream.
+      value = HexFloat<T, Traits>(typename HexFloat<T, Traits>::uint_type(0));
+      is.setstate(std::ios_base::failbit);
+      return true;
+    }
+  }
+  return false;
+}
+
 // Parses a floating point number from the given stream and stores it into the
 // value parameter.
-// If the negate_value parameter is true then the number is negated before
-// it is stored into the value parameter.
+// If negate_value is true then the number may not have a leading minus or
+// plus, and if it successfully parses, then the number is negated before
+// being stored into the value parameter.
 template <typename T, typename Traits>
 inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
                                       HexFloat<T, Traits>& value) {
+  if (RejectParseDueToLeadingSign(is, negate_value, value)) {
+    return is;
+  }
   T val;
   is >> val;
   if (negate_value) {
@@ -729,14 +752,20 @@ inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
 
 // Specialization of ParseNormalFloat for FloatProxy<Float16> values.
 // This will parse the float as it were a 32-bit floating point number,
-// and then round it down to fit into a Float16 value. 
+// and then round it down to fit into a Float16 value.
 // The number is rounded towards zero.
 // Any floating point number that is too large will be rounded to +- infinity.
+// If negate_value is true then the number may not have a leading minus or
+// plus, and if it successfully parses, then the number is negated before
+// being stored into the value parameter.
 template <>
 inline std::istream&
 ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>(
     std::istream& is, bool negate_value,
     HexFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>& value) {
+  if (RejectParseDueToLeadingSign(is, negate_value, value)) {
+    return is;
+  }
   float f;
   is >> f;
   if (negate_value) {
@@ -841,8 +870,8 @@ std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
         if (bits_written) {
           // If we are here the bits represented belong in the fractional
           // part of the float, and we have to adjust the exponent accordingly.
-          fraction =
-                 static_cast<uint_type>(fraction |
+          fraction = static_cast<uint_type>(
+              fraction |
               static_cast<uint_type>(
                   write_bit << (HF::top_bit_left_shift - fraction_index++)));
           exponent = static_cast<int_type>(exponent + 1);
@@ -873,8 +902,8 @@ std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
           // integer.
           exponent = static_cast<int_type>(exponent - 1);
         } else {
-          fraction =
-              static_cast<uint_type>(fraction |
+          fraction = static_cast<uint_type>(
+              fraction |
               static_cast<uint_type>(
                   write_bit << (HF::top_bit_left_shift - fraction_index++)));
         }
@@ -989,8 +1018,9 @@ std::ostream& operator<<(std::ostream& os, const FloatProxy<T>& value) {
   return os;
 }
 
-template<>
-inline std::ostream& operator<< <Float16>(std::ostream& os, const FloatProxy<Float16>& value) {
+template <>
+inline std::ostream& operator<<<Float16>(std::ostream& os,
+                                         const FloatProxy<Float16>& value) {
   os << HexFloat<FloatProxy<Float16>>(value);
   return os;
 }
index 3026226..dc538a1 100644 (file)
 namespace {
 using ::testing::Eq;
 using spvutils::BitwiseCast;
+using spvutils::Float16;
 using spvutils::FloatProxy;
+using spvutils::HexFloat;
+using spvutils::ParseNormalFloat;
 
 // In this file "encode" means converting a number into a string,
 // and "decode" means converting a string into a number.
@@ -1008,5 +1011,113 @@ TEST(HexFloatOperationTests, NanTests) {
   }
 }
 
+// A test case for parsing good and bad HexFloat<FloatProxy<T>> literals.
+template <typename T>
+struct FloatParseCase {
+  std::string literal;
+  bool negate_value;
+  bool expect_success;
+  HexFloat<FloatProxy<T>> expected_value;
+};
+
+using ParseNormalFloatTest = ::testing::TestWithParam<FloatParseCase<float>>;
+
+TEST_P(ParseNormalFloatTest, Samples) {
+  std::stringstream input(GetParam().literal);
+  HexFloat<FloatProxy<float>> parsed_value(0.0f);
+  ParseNormalFloat(input, GetParam().negate_value, parsed_value);
+  if (GetParam().expect_success) {
+    EXPECT_FALSE(input.fail()) << " literal: " << GetParam().literal
+                               << " negate: " << GetParam().negate_value;
+    EXPECT_THAT(parsed_value.value(), Eq(GetParam().expected_value.value()));
+  } else {
+    EXPECT_TRUE(input.fail()) << " literal: " << GetParam().literal
+                              << " negate: " << GetParam().negate_value;
+  }
+}
+
+// Returns a FloatParseCase with expected failure.
+template <typename T>
+FloatParseCase<T> BadFloatParseCase(std::string literal, bool negate_value) {
+  HexFloat<FloatProxy<T>> dummy_value(0.0f);
+  return FloatParseCase<T>{literal, negate_value, false, dummy_value};
+}
+
+// Returns a FloatParseCase that should successfully parse to a given value.
+template <typename T>
+FloatParseCase<T> GoodFloatParseCase(std::string literal, bool negate_value,
+                                     T expected_value) {
+  HexFloat<FloatProxy<T>> proxy_expected_value(expected_value);
+  return FloatParseCase<T>{literal, negate_value, true, proxy_expected_value};
+}
+
+INSTANTIATE_TEST_CASE_P(FloatParse, ParseNormalFloatTest,
+                        ::testing::ValuesIn(std::vector<FloatParseCase<float>>{
+                            // Failing cases due to trivially incorrect syntax.
+                            BadFloatParseCase<float>("abc", false),
+                            BadFloatParseCase<float>("abc", true),
+
+                            // Valid cases.
+                            GoodFloatParseCase<float>("0", false, 0.0f),
+                            GoodFloatParseCase<float>("0.0", false, 0.0f),
+                            GoodFloatParseCase<float>("-0.0", false, -0.0f),
+                            GoodFloatParseCase<float>("2.0", false, 2.0f),
+                            GoodFloatParseCase<float>("-2.0", false, -2.0f),
+                            GoodFloatParseCase<float>("+2.0", false, 2.0f),
+                            // Cases with negate_value being true.
+                            GoodFloatParseCase<float>("0.0", true, -0.0f),
+                            GoodFloatParseCase<float>("2.0", true, -2.0f),
+
+                            // When negate_value is true, we should not accept a
+                            // leading minus or plus.
+                            BadFloatParseCase<float>("-0.0", true),
+                            BadFloatParseCase<float>("-2.0", true),
+                            BadFloatParseCase<float>("+0.0", true),
+                            BadFloatParseCase<float>("+2.0", true),
+                        }));
+
+using ParseNormalFloat16Test =
+    ::testing::TestWithParam<FloatParseCase<Float16>>;
+
+TEST_P(ParseNormalFloat16Test, Samples) {
+  std::stringstream input(GetParam().literal);
+  HexFloat<FloatProxy<Float16>> parsed_value(0.0f);
+  ParseNormalFloat(input, GetParam().negate_value, parsed_value);
+  if (GetParam().expect_success) {
+    EXPECT_FALSE(input.fail()) << " literal: " << GetParam().literal
+                               << " negate: " << GetParam().negate_value;
+    EXPECT_THAT(parsed_value.value(), Eq(GetParam().expected_value.value()));
+  } else {
+    EXPECT_TRUE(input.fail()) << " literal: " << GetParam().literal
+                              << " negate: " << GetParam().negate_value;
+  }
+}
+
+INSTANTIATE_TEST_CASE_P(
+    Float16Parse, ParseNormalFloat16Test,
+    ::testing::ValuesIn(std::vector<FloatParseCase<Float16>>{
+        // Failing cases due to trivially incorrect syntax.
+        BadFloatParseCase<Float16>("abc", false),
+        BadFloatParseCase<Float16>("abc", true),
+
+        // Valid cases.
+        GoodFloatParseCase<Float16>("0", false, uint16_t{0}),
+        GoodFloatParseCase<Float16>("0.0", false, uint16_t{0}),
+        GoodFloatParseCase<Float16>("-0.0", false, uint16_t{0x8000}),
+        GoodFloatParseCase<Float16>("2.0", false, uint16_t{0x4000}),
+        GoodFloatParseCase<Float16>("-2.0", false, uint16_t{0xc000}),
+        GoodFloatParseCase<Float16>("+2.0", false, uint16_t{0x4000}),
+        // Cases with negate_value being true.
+        GoodFloatParseCase<Float16>("0.0", true, uint16_t{0x8000}),
+        GoodFloatParseCase<Float16>("2.0", true, uint16_t{0xc000}),
+
+        // When negate_value is true, we should not accept a leading minus or
+        // plus.
+        BadFloatParseCase<Float16>("-0.0", true),
+        BadFloatParseCase<Float16>("-2.0", true),
+        BadFloatParseCase<Float16>("+0.0", true),
+        BadFloatParseCase<Float16>("+2.0", true),
+    }));
+
 // TODO(awoloszyn): Add fp16 tests and HexFloatTraits.
 }
index 4fabba7..be85eca 100644 (file)
@@ -279,6 +279,48 @@ INSTANTIATE_TEST_CASE_P(
     }));
 // clang-format on
 
+// A test case for invalid floating point literals.
+struct InvalidFloatConstantCase {
+  uint32_t width;
+  std::string literal;
+};
+
+using OpConstantInvalidFloatConstant = spvtest::TextToBinaryTestBase<
+    ::testing::TestWithParam<InvalidFloatConstantCase>>;
+
+TEST_P(OpConstantInvalidFloatConstant, Samples) {
+  // Check both kinds of instructions that take literal floats.
+  for (const auto& instruction : {"OpConstant", "OpSpecConstant"}) {
+    std::stringstream input;
+    input << "%1 = OpTypeFloat " << GetParam().width << "\n"
+          << "%2 = " << instruction << " %1 " << GetParam().literal;
+    std::stringstream expected_error;
+    expected_error << "Invalid " << GetParam().width
+                   << "-bit float literal: " << GetParam().literal;
+    EXPECT_THAT(CompileFailure(input.str()), Eq(expected_error.str()));
+  }
+}
+
+INSTANTIATE_TEST_CASE_P(
+    TextToBinaryInvalidFloatConstant, OpConstantInvalidFloatConstant,
+    ::testing::ValuesIn(std::vector<InvalidFloatConstantCase>{
+        {16, "abc"},
+        {16, "--1"},
+        {16, "-+1"},
+        {16, "+-1"},
+        {16, "++1"},
+        {32, "abc"},
+        {32, "--1"},
+        {32, "-+1"},
+        {32, "+-1"},
+        {32, "++1"},
+        {64, "abc"},
+        {64, "--1"},
+        {64, "-+1"},
+        {64, "+-1"},
+        {64, "++1"},
+    }));
+
 using OpConstantInvalidTypeTest =
     spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;