This adds half-precision constants to spirv-tools.
16-bit floats are always disassembled into hex-float format,
but can be assembled from floating point or hex-float inputs.
### Assembler and disassembler
-* Support 16-bit floating point literals.
* The disassembler could emit helpful annotations in comments. For example:
* Use variable name information from debug instructions to annotate
key operations on variables.
return os;
}
+// Parses a floating point number from the given stream and stores it into the
+// value parameter.
+// If the negate_value parameter is true then the number is negated before
+// it is stored into the value parameter.
template <typename T, typename Traits>
inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
HexFloat<T, Traits>& value) {
return is;
}
+// Specialization of ParseNormalFloat for FloatProxy<Float16> values.
+// This will parse the float as it were a 32-bit floating point number,
+// and then round it down to fit into a Float16 value.
+// The number is rounded towards zero.
+// Any floating point number that is too large will be rounded to +- infinity.
+template <>
+inline std::istream&
+ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>(
+ std::istream& is, bool negate_value,
+ HexFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>& value) {
+ float f;
+ is >> f;
+ if (negate_value) {
+ f = -f;
+ }
+ HexFloat<FloatProxy<float>> float_val(f);
+ float_val.castTo(value, round_direction::kToZero);
+ return is;
+}
+
// Reads a HexFloat from the given stream.
// If the float is not encoded as a hex-float then it will be parsed
// as a regular float.
}
return os;
}
+
+template<>
+inline std::ostream& operator<< <Float16>(std::ostream& os, const FloatProxy<Float16>& value) {
+ os << HexFloat<FloatProxy<Float16>>(value);
+ return os;
+}
}
#endif // _LIBSPIRV_UTIL_HEX_FLOAT_H_
stream_ << word;
break;
case SPV_NUMBER_FLOATING:
- // Assume only 32-bit floats.
- // TODO(dneto): Handle 16-bit floats also.
- stream_ << spvutils::FloatProxy<float>(word);
+ if (operand.number_bit_width == 16) {
+ stream_ << spvutils::FloatProxy<spvutils::Float16>(uint16_t(word & 0xFFFF));
+ } else {
+ // Assume 32-bit floats.
+ stream_ << spvutils::FloatProxy<float>(word);
+ }
break;
default:
assert(false && "Unreachable");
spv_instruction_t* pInst) {
const auto bit_width = assumedBitWidth(type);
switch (bit_width) {
- case 16:
- return diagnostic(SPV_ERROR_INTERNAL)
- << "Unsupported yet: 16-bit float constants.";
+ case 16: {
+ spvutils::HexFloat<FloatProxy<spvutils::Float16>> hVal(0);
+ if (auto error = parseNumber(val, error_code, &hVal,
+ "Invalid 16-bit float literal: "))
+ return error;
+ // getAsFloat will return the spvutils::Float16 value, and get_value
+ // will return a uint16_t representing the bits of the float.
+ // The encoding is therefore correct from the perspective of the SPIR-V
+ // spec since the top 16 bits will be 0.
+ return binaryEncodeU32(
+ static_cast<uint32_t>(hVal.value().getAsFloat().get_value()), pInst);
+ } break;
case 32: {
spvutils::HexFloat<FloatProxy<float>> fVal(0.0f);
if (auto error = parseNumber(val, error_code, &fVal,
};
using TextToBinaryTest = TextToBinaryTestBase<::testing::Test>;
-
} // namespace spvtest
+using RoundTripTest =
+ spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
+
#endif // LIBSPIRV_TEST_TEST_FIXTURE_H_
}));
// clang-format on
-using RoundTripTest =
- spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
-
const int64_t kMaxUnsigned48Bit = (int64_t(1) << 48) - 1;
const int64_t kMaxSigned48Bit = (int64_t(1) << 47) - 1;
const int64_t kMinSigned48Bit = -kMaxSigned48Bit - 1;
-TEST_P(RoundTripTest, Sample) {
- EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()))
- << GetParam();
-}
-
INSTANTIATE_TEST_CASE_P(
OpConstantRoundTrip, RoundTripTest,
::testing::ValuesIn(std::vector<std::string>{
"%1 = OpTypeFloat 64\n%2 = OpConstant %1 -1.79769e+308\n",
}));
+INSTANTIATE_TEST_CASE_P(
+ OpConstantHalfRoundTrip, RoundTripTest,
+ ::testing::ValuesIn(std::vector<std::string>{
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x0p+0\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x0p+0\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p+0\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.1p+0\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.01p-1\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.8p+1\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffcp+1\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p+0\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.1p+0\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p-1\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.8p+1\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffcp+1\n",
+
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p-16\n", // some denorms
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p-24\n",
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p-24\n",
+
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p+16\n", // +inf
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p+16\n", // -inf
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p+16\n", // -inf
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.01p+16\n", // nan
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.11p+16\n", // nan
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffp+16\n", // nan
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffcp+16\n", // nan
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.004p+16\n", // nan
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p+16\n", // -nan
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.11p+16\n", // -nan
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffp+16\n", // -nan
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffcp+16\n", // -nan
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.004p+16\n", // -nan
+ }));
+
// clang-format off
// (Clang-format really wants to break up these strings across lines.
INSTANTIATE_TEST_CASE_P(
MakeSwitchTestCase(64, 1, "0x700000123", {0x123, 7}, "12", {12, 0}),
})));
-using RoundTripTest =
- spvtest::TextToBinaryTestBase<::testing::TestWithParam<std::string>>;
-
-TEST_P(RoundTripTest, Sample) {
- EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()));
-}
-
INSTANTIATE_TEST_CASE_P(
OpSwitchRoundTripUnsignedIntegers, RoundTripTest,
::testing::ValuesIn(std::vector<std::string>({
{"!0xff800001", 0xff800001}, // NaN
}));
+using TextToBinaryHalfValueTest = spvtest::TextToBinaryTestBase<
+ ::testing::TestWithParam<std::pair<std::string, uint32_t>>>;
+
+TEST_P(TextToBinaryHalfValueTest, Samples) {
+ const std::string input =
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 " + GetParam().first;
+ EXPECT_THAT(CompiledInstructions(input),
+ Eq(Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 16}),
+ MakeInstruction(SpvOpConstant,
+ {1, 2, GetParam().second})})));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ HalfValues, TextToBinaryHalfValueTest,
+ ::testing::ValuesIn(std::vector<std::pair<std::string, uint32_t>>{
+ {"0.0", 0x00000000},
+ {"1.0", 0x00003c00},
+ {"1.000844", 0x00003c00}, // Truncate to 1.0
+ {"1.000977", 0x00003c01}, // Don't have to truncate
+ {"1.001465", 0x00003c01}, // Truncate to 1.0000977
+ {"1.5", 0x00003e00},
+ {"-1.0", 0x0000bc00},
+ {"2.0", 0x00004000},
+ {"-2.0", 0x0000c000},
+ {"0x1p1", 0x00004000},
+ {"-0x1p1", 0x0000c000},
+ {"0x1.8p1", 0x00004200},
+ {"0x1.8p4", 0x00004e00},
+ {"0x1.801p4", 0x00004e00},
+ {"0x1.804p4", 0x00004e01},
+ }));
+
TEST(AssemblyContextParseNarrowSignedIntegers, Sample) {
AssemblyContext context(AutoText(""), nullptr);
const spv_result_t ec = SPV_FAILED_MATCH;
#include "UnitSPIRV.h"
#include "gmock/gmock.h"
+#include "TestFixture.h"
namespace {
EXPECT_THAT(s.str(), Eq("xx10 0x0000000a 0x00000010 xx11"));
}
+TEST_P(RoundTripTest, Sample) {
+ EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam()))
+ << GetParam();
+}
+
} // anonymous namespace
}
} // namespace spvtest
-
#endif // LIBSPIRV_TEST_UNITSPIRV_H_