Support 16-bit float in assembler and disassembler
authorAndrew Woloszyn <awoloszyn@google.com>
Fri, 8 Jan 2016 14:54:42 +0000 (09:54 -0500)
committerDavid Neto <dneto@google.com>
Fri, 8 Jan 2016 15:48:39 +0000 (10:48 -0500)
This adds half-precision constants to spirv-tools.
16-bit floats are always disassembled into hex-float format,
but can be assembled from floating point or hex-float inputs.

README.md
include/util/hex_float.h
source/disassemble.cpp
source/text_handler.cpp
test/TestFixture.h
test/TextToBinary.Constant.cpp
test/TextToBinary.ControlFlow.cpp
test/TextToBinary.cpp
test/UnitSPIRV.cpp
test/UnitSPIRV.h

index 4c8d533..11811fe 100644 (file)
--- a/README.md
+++ b/README.md
@@ -193,7 +193,6 @@ It supports the standard `googletest` command line options.
 
 ### Assembler and disassembler
 
-* Support 16-bit floating point literals.
 * The disassembler could emit helpful annotations in comments.  For example:
   * Use variable name information from debug instructions to annotate
     key operations on variables.
index 9d8733b..f3d70b2 100644 (file)
@@ -698,6 +698,10 @@ std::ostream& operator<<(std::ostream& os, const HexFloat<T, Traits>& value) {
   return os;
 }
 
+// Parses a floating point number from the given stream and stores it into the
+// value parameter.
+// If the negate_value parameter is true then the number is negated before
+// it is stored into the value parameter.
 template <typename T, typename Traits>
 inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
                                       HexFloat<T, Traits>& value) {
@@ -710,6 +714,26 @@ inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
   return is;
 }
 
+// Specialization of ParseNormalFloat for FloatProxy<Float16> values.
+// This will parse the float as if it were a 32-bit floating point number,
+// and then round it down to fit into a Float16 value.
+// The number is rounded towards zero.
+// Any floating point number that is too large will be rounded to +- infinity.
+template <>
+inline std::istream&
+ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>(
+    std::istream& is, bool negate_value,
+    HexFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>& value) {
+  float f;
+  is >> f;
+  if (negate_value) {
+    f = -f;
+  }
+  HexFloat<FloatProxy<float>> float_val(f);
+  float_val.castTo(value, round_direction::kToZero);
+  return is;
+}
+
 // Reads a HexFloat from the given stream.
 // If the float is not encoded as a hex-float then it will be parsed
 // as a regular float.
@@ -940,6 +964,12 @@ std::ostream& operator<<(std::ostream& os, const FloatProxy<T>& value) {
   }
   return os;
 }
+
+template<>
+inline std::ostream& operator<< <Float16>(std::ostream& os, const FloatProxy<Float16>& value) {
+  os << HexFloat<FloatProxy<Float16>>(value);
+  return os;
+}
 }
 
 #endif  // _LIBSPIRV_UTIL_HEX_FLOAT_H_
index a03e2a5..927914a 100644 (file)
@@ -233,9 +233,12 @@ void Disassembler::EmitOperand(const spv_parsed_instruction_t& inst,
             stream_ << word;
             break;
           case SPV_NUMBER_FLOATING:
-            // Assume only 32-bit floats.
-            // TODO(dneto): Handle 16-bit floats also.
-            stream_ << spvutils::FloatProxy<float>(word);
+            if (operand.number_bit_width == 16) {
+              stream_ << spvutils::FloatProxy<spvutils::Float16>(uint16_t(word & 0xFFFF));
+            } else {
+              // Assume 32-bit floats.
+              stream_ << spvutils::FloatProxy<float>(word);
+            }
             break;
           default:
             assert(false && "Unreachable");
index 7d15161..adc492c 100644 (file)
@@ -387,9 +387,18 @@ spv_result_t AssemblyContext::binaryEncodeFloatingPointLiteral(
     spv_instruction_t* pInst) {
   const auto bit_width = assumedBitWidth(type);
   switch (bit_width) {
-    case 16:
-      return diagnostic(SPV_ERROR_INTERNAL)
-             << "Unsupported yet: 16-bit float constants.";
+    case 16: {
+      spvutils::HexFloat<FloatProxy<spvutils::Float16>> hVal(0);
+      if (auto error = parseNumber(val, error_code, &hVal,
+                                   "Invalid 16-bit float literal: "))
+        return error;
+      // getAsFloat will return the spvutils::Float16 value, and get_value
+      // will return a uint16_t representing the bits of the float.
+      // The encoding is therefore correct from the perspective of the SPIR-V
+      // spec since the top 16 bits will be 0.
+      return binaryEncodeU32(
+          static_cast<uint32_t>(hVal.value().getAsFloat().get_value()), pInst);
+    } break;
     case 32: {
       spvutils::HexFloat<FloatProxy<float>> fVal(0.0f);
       if (auto error = parseNumber(val, error_code, &fVal,
index 2d3a773..68c5c25 100644 (file)
@@ -183,7 +183,9 @@ class TextToBinaryTestBase : public T {
 };
 
 using TextToBinaryTest = TextToBinaryTestBase<::testing::Test>;
-
 }  // namespace spvtest
 
+using RoundTripTest =
+    spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
+
 #endif  // LIBSPIRV_TEST_TEST_FIXTURE_H_
index 6654cbe..4c7190c 100644 (file)
@@ -341,18 +341,10 @@ INSTANTIATE_TEST_CASE_P(
     }));
 // clang-format on
 
-using RoundTripTest =
-    spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
-
 const int64_t kMaxUnsigned48Bit = (int64_t(1) << 48) - 1;
 const int64_t kMaxSigned48Bit = (int64_t(1) << 47) - 1;
 const int64_t kMinSigned48Bit = -kMaxSigned48Bit - 1;
 
-TEST_P(RoundTripTest, Sample) {
-  EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()))
-      << GetParam();
-}
-
 INSTANTIATE_TEST_CASE_P(
     OpConstantRoundTrip, RoundTripTest,
     ::testing::ValuesIn(std::vector<std::string>{
@@ -396,6 +388,41 @@ INSTANTIATE_TEST_CASE_P(
         "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -1.79769e+308\n",
     }));
 
+INSTANTIATE_TEST_CASE_P(
+    OpConstantHalfRoundTrip, RoundTripTest,
+    ::testing::ValuesIn(std::vector<std::string>{
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x0p+0\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x0p+0\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p+0\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.1p+0\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.01p-1\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.8p+1\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffcp+1\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p+0\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.1p+0\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p-1\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.8p+1\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffcp+1\n",
+
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p-16\n", // some denorms
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p-24\n",
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p-24\n",
+
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p+16\n", // +inf
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p+16\n", // -inf
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p+16\n", // -nan
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.01p+16\n", // nan
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.11p+16\n", // nan
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffp+16\n", // nan
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffcp+16\n", // nan
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.004p+16\n", // nan
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p+16\n", // -nan
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.11p+16\n", // -nan
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffp+16\n", // -nan
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffcp+16\n", // -nan
+        "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.004p+16\n", // -nan
+    }));
+
 // clang-format off
 // (Clang-format really wants to break up these strings across lines.
 INSTANTIATE_TEST_CASE_P(
index 5b53c88..0cd16d9 100644 (file)
@@ -272,13 +272,6 @@ INSTANTIATE_TEST_CASE_P(
         MakeSwitchTestCase(64, 1, "0x700000123", {0x123, 7}, "12", {12, 0}),
     })));
 
-using RoundTripTest =
-    spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
-
-TEST_P(RoundTripTest, Sample) {
-  EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()));
-}
-
 INSTANTIATE_TEST_CASE_P(
     OpSwitchRoundTripUnsignedIntegers, RoundTripTest,
     ::testing::ValuesIn(std::vector<std::string>({
index d6f2033..6bd7513 100644 (file)
@@ -245,6 +245,38 @@ INSTANTIATE_TEST_CASE_P(
         {"!0xff800001", 0xff800001},  // NaN
     }));
 
+using TextToBinaryHalfValueTest = spvtest::TextToBinaryTestBase<
+    ::testing::TestWithParam<std::pair<std::string, uint32_t>>>;
+
+TEST_P(TextToBinaryHalfValueTest, Samples) {
+  const std::string input =
+      "%1 = OpTypeFloat 16\n%2 = OpConstant %1 " + GetParam().first;
+  EXPECT_THAT(CompiledInstructions(input),
+              Eq(Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 16}),
+                              MakeInstruction(SpvOpConstant,
+                                              {1, 2, GetParam().second})})));
+}
+
+INSTANTIATE_TEST_CASE_P(
+    HalfValues, TextToBinaryHalfValueTest,
+    ::testing::ValuesIn(std::vector<std::pair<std::string, uint32_t>>{
+        {"0.0", 0x00000000},
+        {"1.0", 0x00003c00},
+        {"1.000844", 0x00003c00}, // Truncate to 1.0
+        {"1.000977", 0x00003c01}, // Don't have to truncate
+        {"1.001465", 0x00003c01}, // Truncate to 1.000977
+        {"1.5", 0x00003e00},
+        {"-1.0", 0x0000bc00},
+        {"2.0", 0x00004000},
+        {"-2.0", 0x0000c000},
+        {"0x1p1", 0x00004000},
+        {"-0x1p1", 0x0000c000},
+        {"0x1.8p1", 0x00004200},
+        {"0x1.8p4", 0x00004e00},
+        {"0x1.801p4", 0x00004e00},
+        {"0x1.804p4", 0x00004e01},
+    }));
+
 TEST(AssemblyContextParseNarrowSignedIntegers, Sample) {
   AssemblyContext context(AutoText(""), nullptr);
   const spv_result_t ec = SPV_FAILED_MATCH;
index c38c4bc..43fa26d 100644 (file)
@@ -27,6 +27,7 @@
 #include "UnitSPIRV.h"
 
 #include "gmock/gmock.h"
+#include "TestFixture.h"
 
 namespace {
 
@@ -56,4 +57,9 @@ TEST(WordVectorPrintTo, PreservesFlagsAndFill) {
   EXPECT_THAT(s.str(), Eq("xx10 0x0000000a 0x00000010 xx11"));
 }
 
+TEST_P(RoundTripTest, Sample) {
+  EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()))
+      << GetParam();
+}
+
 }  // anonymous namespace
index 3fbf3e5..5cec434 100644 (file)
@@ -215,5 +215,4 @@ inline std::string MakeLongUTF8String(size_t num_4_byte_chars) {
 }
 
 }  // namespace spvtest
-
 #endif  // LIBSPIRV_TEST_UNITSPIRV_H_