Added Float16, and HexFloat conversions
authorAndrew Woloszyn <awoloszyn@google.com>
Thu, 3 Dec 2015 21:30:21 +0000 (16:30 -0500)
committerAndrew Woloszyn <awoloszyn@google.com>
Tue, 8 Dec 2015 19:41:57 +0000 (14:41 -0500)
include/util/hex_float.h
test/HexFloat.cpp

index a3bdc31..6f4b7c0 100644 (file)
 
 namespace spvutils {
 
+class Float16 {
+ public:
+  Float16(uint16_t v) : val(v) {}
+  Float16() = default;
+  static bool isNan(const Float16 val) {
+    return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) != 0);
+  }
+  Float16(const Float16& other) { val = other.val; }
+  uint16_t get_value() const { return val; }
+
+ private:
+  uint16_t val;
+};
+
+// To specialize this type, you must override uint_type to define
+// an unsigned integer that can fit your floating point type.
+// You must also add a isNan function that returns true if
+// a value is Nan.
 template <typename T>
 struct FloatProxyTraits {
-  typedef void uint_type;
+  using uint_type = void;
 };
 
 template <>
 struct FloatProxyTraits<float> {
-  typedef uint32_t uint_type;
+  using uint_type = uint32_t;
+  static bool isNan(float f) { return std::isnan(f); }
 };
 
 template <>
 struct FloatProxyTraits<double> {
-  typedef uint64_t uint_type;
+  using uint_type = uint64_t;
+  static bool isNan(double f) { return std::isnan(f); }
+};
+
+template <>
+struct FloatProxyTraits<Float16> {
+  using uint_type = uint16_t;
+  static bool isNan(Float16 f) { return Float16::isNan(f); }
 };
 
 // Since copying a floating point number (especially if it is NaN)
@@ -86,7 +112,7 @@ class FloatProxy {
   uint_type data() const { return data_; }
 
   // Returns true if the value represents any type of NaN.
-  bool isNan() { return std::isnan(getAsFloat()); }
+  bool isNan() { return FloatProxyTraits<T>::isNan(getAsFloat()); }
 
  private:
   uint_type data_;
@@ -111,9 +137,13 @@ std::istream& operator>>(std::istream& is, FloatProxy<T>& value) {
 template <typename T>
 struct HexFloatTraits {
   // Integer type that can store this hex-float.
-  typedef void uint_type;
+  using uint_type = void;
   // Signed integer type that can store this hex-float.
-  typedef void int_type;
+  using int_type = void;
+  // The numerical type that this HexFloat represents.
+  using underlying_type = void;
+  // The type needed to construct the underlying type.
+  using native_type = void;
   // The number of bits that are actually relevant in the uint_type.
   // This allows us to deal with, for example, 24-bit values in a 32-bit
   // integer.
@@ -131,8 +161,10 @@ struct HexFloatTraits {
 // 1 sign bit, 8 exponent bits, 23 fractional bits.
 template <>
 struct HexFloatTraits<FloatProxy<float>> {
-  typedef uint32_t uint_type;
-  typedef int32_t int_type;
+  using uint_type = uint32_t;
+  using int_type = int32_t;
+  using underlying_type = FloatProxy<float>;
+  using native_type = float;
   static const uint_type num_used_bits = 32;
   static const uint_type num_exponent_bits = 8;
   static const uint_type num_fraction_bits = 23;
@@ -143,14 +175,38 @@ struct HexFloatTraits<FloatProxy<float>> {
 // 1 sign bit, 11 exponent bits, 52 fractional bits.
 template <>
 struct HexFloatTraits<FloatProxy<double>> {
-  typedef uint64_t uint_type;
-  typedef int64_t int_type;
+  using uint_type = uint64_t;
+  using int_type = int64_t;
+  using underlying_type = FloatProxy<double>;
+  using native_type = double;
   static const uint_type num_used_bits = 64;
   static const uint_type num_exponent_bits = 11;
   static const uint_type num_fraction_bits = 52;
   static const uint_type exponent_bias = 1023;
 };
 
+// Traits for IEEE half.
+// 1 sign bit, 5 exponent bits, 10 fractional bits.
+template <>
+struct HexFloatTraits<FloatProxy<Float16>> {
+  using uint_type = uint16_t;
+  using int_type = int16_t;
+  using underlying_type = uint16_t;
+  using native_type = uint16_t;
+  static const uint_type num_used_bits = 16;
+  static const uint_type num_exponent_bits = 5;
+  static const uint_type num_fraction_bits = 10;
+  static const uint_type exponent_bias = 15;
+};
+
+enum class round_direction {
+  kToZero,
+  kToNearestEven,
+  kToPositiveInfinity,
+  kToNegativeInfinity,
+  max = kToNegativeInfinity
+};
+
 // Template class that houses a floating pointer number.
 // It exposes a number of constants based on the provided traits to
 // assist in interpreting the bits of the value.
@@ -159,6 +215,8 @@ class HexFloat {
  public:
   using uint_type = typename Traits::uint_type;
   using int_type = typename Traits::int_type;
+  using underlying_type = typename Traits::underlying_type;
+  using native_type = typename Traits::native_type;
 
   explicit HexFloat(T f) : value_(f) {}
 
@@ -190,10 +248,15 @@ class HexFloat {
       spvutils::SetBits<uint_type, 0,
                         num_fraction_bits + num_overflow_bits>::get;
 
-  // The topmost bit in the fraction. (The first non-implicit bit).
+  // The topmost bit in the nibble-aligned fraction.
   static const uint_type fraction_top_bit =
       uint_type(1) << (num_fraction_bits + num_overflow_bits - 1);
 
+  // The least significant bit in the exponent, which is also the bit
+  // immediately to the left of the significand.
+  static const uint_type first_exponent_bit = uint_type(1)
+                                              << (num_fraction_bits);
+
   // The mask for the encoded fraction. It does not include the
   // implicit bit.
   static const uint_type fraction_encode_mask =
@@ -213,12 +276,334 @@ class HexFloat {
   static const uint32_t fraction_right_shift =
       (sizeof(uint_type) * 8) - num_fraction_bits;
 
+  // The maximum representable unbiased exponent.
+  static const int_type max_exponent =
+      (exponent_mask >> num_fraction_bits) - exponent_bias;
+  // The minimum representable exponent for normalized numbers.
+  static const int_type min_exponent = -static_cast<int_type>(exponent_bias);
+
+  // Returns the bits associated with the value.
+  uint_type getBits() const { return spvutils::BitwiseCast<uint_type>(value_); }
+
+  // Returns the bits associated with the value, without the leading sign bit.
+  uint_type getUnsignedBits() const {
+    return spvutils::BitwiseCast<uint_type>(value_) & ~sign_mask;
+  }
+
+  // Returns the bits associated with the exponent, shifted to start at the
+  // lsb of the type.
+  const uint_type getExponentBits() const {
+    return (getBits() & exponent_mask) >> num_fraction_bits;
+  }
+
+  // Returns the exponent in unbiased form. This is the exponent in the
+  // human-friendly form.
+  const int_type getUnbiasedExponent() const {
+    return (static_cast<int_type>(getExponentBits()) - exponent_bias);
+  }
+
+  // Returns just the significand bits from the value.
+  const uint_type getSignificandBits() const {
+    return getBits() & fraction_encode_mask;
+  }
+
+  // If the number was normalized, returns the unbiased exponent.
+  // If the number was denormal, normalize the exponent first.
+  const int_type getUnbiasedNormalizedExponent() const {
+    if ((getBits() & ~sign_mask) == 0) {  // special case if everything is 0
+      return 0;
+    }
+    int_type exp = getUnbiasedExponent();
+    if (exp == min_exponent) {  // We are in denorm land.
+      uint_type significand_bits = getSignificandBits();
+      while ((significand_bits & (first_exponent_bit >> 1)) == 0) {
+        significand_bits <<= 1;
+        exp -= 1;
+      }
+      significand_bits &= fraction_encode_mask;
+    }
+    return exp;
+  }
+
+  // Returns the signficand after it has been normalized.
+  const uint_type getNormalizedSignificand() const {
+    int_type unbiased_exponent = getUnbiasedNormalizedExponent();
+    uint_type significand = getSignificandBits();
+    for (int_type i = unbiased_exponent; i <= min_exponent; ++i) {
+      significand <<= 1;
+    }
+    significand &= fraction_encode_mask;
+    return significand;
+  }
+
+  // Returns true if this number represents a negative value.
+  bool isNegative() const { return (getBits() & sign_mask) != 0; }
+
+  // Sets this HexFloat from the individual components.
+  // Note this assumes EVERY significand is normalized, and has an implicit
+  // leading one. This means that the only way that this method will set 0,
+  // is if you set a number so denormalized that it underflows.
+  // Do not use this method with raw bits extracted from a subnormal number,
+  // since subnormals do not have an implicit leading 1 in the significand.
+  // The significand is also expected to be in the
+  // lowest-most num_fraction_bits of the uint_type.
+  // The exponent is expected to be unbiased, meaning an exponent of
+  // 0 actually means 0.
+  // If underflow_round_up is set, then on underflow, if a number is non-0
+  // and would underflow, we round up to the smallest denorm.
+  void setFromSignUnbiasedExponentAndNormalizedSignificand(
+      bool negative, int_type exponent, uint_type significand,
+      bool round_denorm_up) {
+    bool significand_is_zero = significand == 0;
+
+    if (exponent <= min_exponent) {
+      // If this was denormalized, then we have to shift the bit on, meaning
+      // the significand is not zero.
+      significand_is_zero = false;
+      significand |= first_exponent_bit;
+      significand >>= 1;
+    }
+
+    while (exponent < min_exponent) {
+      significand >>= 1;
+      ++exponent;
+    }
+
+    if (exponent == min_exponent) {
+      if (significand == 0 && !significand_is_zero && round_denorm_up) {
+        significand = 0x1;
+      }
+    }
+
+    uint_type value = 0;
+    if (negative) {
+      value |= sign_mask;
+    }
+    exponent += exponent_bias;
+    assert(exponent >= 0);
+
+    // put it all together
+    exponent = (exponent << exponent_left_shift) & exponent_mask;
+    significand &= fraction_encode_mask;
+    value |= exponent | significand;
+    value_ = BitwiseCast<T>(value);
+  }
+
+  // Increments the significand of this number by the given amount.
+  // If this would spill the significand into the implicit bit,
+  // carry is set to true and the significand is shifted to fit into
+  // the correct location, otherwise carry is set to false.
+  // All significands and to_increment are assumed to be within the bounds
+  // for a valid significand.
+  static uint_type incrementSignificand(uint_type significand,
+                                        uint_type to_increment, bool* carry) {
+    significand += to_increment;
+    *carry = false;
+    if (significand & first_exponent_bit) {
+      *carry = true;
+      // The implicit 1-bit will have carried, so we should zero-out the
+      // top bit and shift back.
+      significand &= ~first_exponent_bit;
+      significand >>= 1;
+    }
+    return significand;
+  }
+
+  // These exist because MSVC throws warnings on negative right-shifts
+  // even if they are not going to be executed. Eg:
+  // constant_number < 0? 0: constant_number
+  // These convert the negative left-shifts into right shifts.
+
+  template <int_type N, typename enable = void>
+  struct negatable_left_shift {
+    static uint_type val(uint_type val) { return val >> -N; }
+  };
+
+  template <int_type N>
+  struct negatable_left_shift<N, typename std::enable_if<N >= 0>::type> {
+    static uint_type val(uint_type val) { return val << N; }
+  };
+
+  template <int_type N, typename enable = void>
+  struct negatable_right_shift {
+    static uint_type val(uint_type val) { return val << -N; }
+  };
+
+  template <int_type N>
+  struct negatable_right_shift<N, typename std::enable_if<N >= 0>::type> {
+    static uint_type val(uint_type val) { return val >> N; }
+  };
+
+  // Returns the significand, rounded to fit in a significand in
+  // other_T. This is shifted so that the most significant
+  // bit of the rounded number lines up with the most significant bit
+  // of the returned significand.
+  template <typename other_T>
+  typename other_T::uint_type getRoundedNormalizedSignificand(
+      round_direction dir, bool* carry_bit) {
+    using other_uint_type = typename other_T::uint_type;
+    static const int_type num_throwaway_bits =
+        static_cast<int_type>(num_fraction_bits) -
+        static_cast<int_type>(other_T::num_fraction_bits);
+
+    static const uint_type last_significant_bit =
+        (num_throwaway_bits < 0)
+            ? 0
+            : negatable_left_shift<num_throwaway_bits>::val(1u);
+    static const uint_type first_rounded_bit =
+        (num_throwaway_bits < 1)
+            ? 0
+            : negatable_left_shift<num_throwaway_bits - 1>::val(1u);
+
+    static const uint_type throwaway_mask_bits =
+        num_throwaway_bits > 0 ? num_throwaway_bits : 0;
+    static const uint_type throwaway_mask =
+        spvutils::SetBits<uint_type, 0, throwaway_mask_bits>::get;
+
+    *carry_bit = false;
+    other_uint_type out_val = 0;
+    uint_type significand = getNormalizedSignificand();
+    // If we are up-casting, then we just have to shift to the right location.
+    if (num_throwaway_bits <= 0) {
+      out_val = significand;
+      uint_type shift_amount = -num_throwaway_bits;
+      out_val <<= shift_amount;
+      return out_val;
+    }
+
+    // If every non-representable bit is 0, then we don't have any casting to
+    // do.
+    if ((significand & throwaway_mask) == 0) {
+      return static_cast<other_uint_type>(
+          negatable_right_shift<num_throwaway_bits>::val(significand));
+    }
+
+    bool round_away_from_zero = false;
+    // We actually have to narrow the significand here, so we have to follow the
+    // rounding rules.
+    switch (dir) {
+      case round_direction::kToZero:
+          break;
+      case round_direction::kToPositiveInfinity:
+        round_away_from_zero = !isNegative();
+        break;
+      case round_direction::kToNegativeInfinity:
+        round_away_from_zero = isNegative();
+        break;
+      case round_direction::kToNearestEven:
+        // Have to round down, round bit is 0
+        if ((first_rounded_bit & significand) == 0) {
+            break;
+        }
+        if (((significand & throwaway_mask) & ~first_rounded_bit) != 0) {
+          // If any subsequent bit of the rounded portion is non-0 then we round
+          // up.
+          round_away_from_zero = true;
+          break;
+        }
+        // We are exactly half-way between 2 numbers, pick even.
+        if ((significand & last_significant_bit) != 0) {
+          // 1 for our last bit, round up.
+          round_away_from_zero = true;
+          break;
+        }
+        break;
+    }
+
+    if (round_away_from_zero) {
+      return static_cast<other_uint_type>(
+          negatable_right_shift<num_throwaway_bits>::val(incrementSignificand(
+              significand, last_significant_bit, carry_bit)));
+    } else {
+      return static_cast<other_uint_type>(
+          negatable_right_shift<num_throwaway_bits>::val(significand));
+    }
+    // We really shouldn't get here.
+    assert(false && "We should not have ended up here");
+    return 0;
+  }
+
+  // Casts this value to another HexFloat. If the cast is widening,
+  // then round_dir is ignored. If the cast is narrowing, then
+  // the result is rounded in the direction specified.
+  // This number will retain Nan and Inf values.
+  // It will also saturate to Inf if the number overflows, and
+  // underflow to (0 or min depending on rounding) if the number underflows.
+  template <typename other_T>
+  void castTo(other_T& other, round_direction round_dir) {
+    other = other_T(static_cast<typename other_T::native_type>(0));
+    bool negate = isNegative();
+    if (getUnsignedBits() == 0) {
+      if (negate) {
+        other.set_value(-other.value());
+      }
+      return;
+    }
+    uint_type significand = getSignificandBits();
+    bool carried = false;
+    typename other_T::uint_type rounded_significand =
+        getRoundedNormalizedSignificand<other_T>(round_dir, &carried);
+
+    int_type exponent = getUnbiasedExponent();
+    if (exponent == min_exponent) {
+      // If we are denormal, normalize the exponent, so that we can encode
+      // easily.
+      exponent += 1;
+      for (uint_type check_bit = first_exponent_bit >> 1; check_bit != 0;
+           check_bit >>= 1) {
+        exponent -= 1;
+        if (check_bit & significand) break;
+      }
+    }
+
+    bool is_nan =
+        (getBits() & exponent_mask) == exponent_mask && significand != 0;
+    bool is_inf =
+        !is_nan &&
+        ((exponent + carried) > static_cast<int_type>(other_T::exponent_bias) ||
+         (significand == 0 && (getBits() & exponent_mask) == exponent_mask));
+
+    // If we are Nan or Inf we should pass that through.
+    if (is_inf) {
+      other.set_value(BitwiseCast<typename other_T::underlying_type>(
+          static_cast<typename other_T::uint_type>(
+              (negate ? other_T::sign_mask : 0) | other_T::exponent_mask)));
+      return;
+    }
+    if (is_nan) {
+      typename other_T::uint_type shifted_significand;
+      shifted_significand = static_cast<typename other_T::uint_type>(
+          negatable_left_shift<other_T::num_fraction_bits -
+                               num_fraction_bits>::val(significand));
+
+      // We are some sort of Nan. We try to keep the bit-pattern of the Nan
+      // as close as possible. If we had to shift off bits so we are 0, then we
+      // just set the last bit.
+      other.set_value(BitwiseCast<typename other_T::underlying_type>(
+          static_cast<typename other_T::uint_type>(
+              (negate ? other_T::sign_mask : 0) | other_T::exponent_mask |
+              (shifted_significand == 0 ? 0x1 : shifted_significand))));
+      return;
+    }
+
+    bool round_underflow_up =
+        isNegative() ? round_dir == round_direction::kToNegativeInfinity
+                     : round_dir == round_direction::kToPositiveInfinity;
+
+    // setFromSignUnbiasedExponentAndNormalizedSignificand will
+    // zero out any underflowing value (but retain the sign).
+    other.setFromSignUnbiasedExponentAndNormalizedSignificand(
+        negate, exponent, rounded_significand, round_underflow_up);
+    return;
+  }
+
  private:
   T value_;
 
   static_assert(num_used_bits ==
                     Traits::num_exponent_bits + Traits::num_fraction_bits + 1,
                 "The number of bits do not fit");
+  static_assert(sizeof(T) == sizeof(uint_type), "The type sizes do not match");
 };
 
 // Returns 4 bits represented by the hex character.
index 3fdea66..8f3a36d 100644 (file)
@@ -528,5 +528,470 @@ INSTANTIATE_TEST_CASE_P(
 
         })));
 
+// double is used so that unbiased_exponent can be used with the output
+// of ldexp directly.
+int32_t unbiased_exponent(double f) {
+  return spvutils::HexFloat<spvutils::FloatProxy<float>>(
+      static_cast<float>(f)).getUnbiasedNormalizedExponent();
+}
+
+int16_t unbiased_half_exponent(uint16_t f) {
+  return spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>(f)
+      .getUnbiasedNormalizedExponent();
+}
+
+TEST(HexFloatOperationTest, UnbiasedExponent) {
+  // Float cases
+  EXPECT_EQ(0, unbiased_exponent(ldexp(1.0f, 0)));
+  EXPECT_EQ(-32, unbiased_exponent(ldexp(1.0f, -32)));
+  EXPECT_EQ(42, unbiased_exponent(ldexp(1.0f, 42)));
+  EXPECT_EQ(125, unbiased_exponent(ldexp(1.0f, 125)));
+  // Saturates to 128
+  EXPECT_EQ(128, unbiased_exponent(ldexp(1.0f, 256)));
+
+  EXPECT_EQ(-100, unbiased_exponent(ldexp(1.0f, -100)));
+  EXPECT_EQ(-127, unbiased_exponent(ldexp(1.0f, -127))); // First denorm
+  EXPECT_EQ(-128, unbiased_exponent(ldexp(1.0f, -128)));
+  EXPECT_EQ(-129, unbiased_exponent(ldexp(1.0f, -129)));
+  EXPECT_EQ(-140, unbiased_exponent(ldexp(1.0f, -140)));
+  // Smallest representable number
+  EXPECT_EQ(-126 - 23, unbiased_exponent(ldexp(1.0f, -126 - 23)));
+  // Should get rounded to 0 first.
+  EXPECT_EQ(0, unbiased_exponent(ldexp(1.0f, -127 - 23)));
+
+  // Float16 cases
+  // The exponent is represented in the bits 0x7C00
+  // The offset is -15
+  EXPECT_EQ(0, unbiased_half_exponent(0x3C00));
+  EXPECT_EQ(3, unbiased_half_exponent(0x4800));
+  EXPECT_EQ(-1, unbiased_half_exponent(0x3800));
+  EXPECT_EQ(-14, unbiased_half_exponent(0x0400));
+  EXPECT_EQ(16, unbiased_half_exponent(0x7C00));
+  EXPECT_EQ(10, unbiased_half_exponent(0x6400));
+
+  // Smallest representable number
+  EXPECT_EQ(-24, unbiased_half_exponent(0x0001));
+}
+
+// Creates a float that is the sum of 1/(2 ^ fractions[i]) for i in factions
+float float_fractions(const std::vector<uint32_t>& fractions) {
+  float f = 0;
+  for(int32_t i: fractions) {
+    f += ldexp(1.0f, -i);
+  }
+  return f;
+}
+
+// Returns the normalized significand of a HexFloat<FloatProxy<float>>
+// that was created by calling float_fractions with the input fractions,
+// raised to the power of exp.
+uint32_t normalized_significand(const std::vector<uint32_t>& fractions, uint32_t exp) {
+  return spvutils::HexFloat<spvutils::FloatProxy<float>>(
+             static_cast<float>(ldexp(float_fractions(fractions), exp)))
+      .getNormalizedSignificand();
+}
+
+// Sets the bits from MSB to LSB of the significand part of a float.
+// For example 0 would set the bit 23 (counting from LSB to MSB),
+// and 1 would set the 22nd bit.
+uint32_t bits_set(const std::vector<uint32_t>& bits) {
+  const uint32_t top_bit = 1u << 22u;
+  uint32_t val= 0;
+  for(uint32_t i: bits) {
+    val |= top_bit >> i;
+  }
+  return val;
+}
+
+// The same as bits_set but for a Float16 value instead of 32-bit floating
+// point.
+uint16_t half_bits_set(const std::vector<uint32_t>& bits) {
+  const uint32_t top_bit = 1u << 9u;
+  uint32_t val= 0;
+  for(uint32_t i: bits) {
+    val |= top_bit >> i;
+  }
+  return val;
+}
+
+TEST(HexFloatOperationTest, NormalizedSignificand) {
+  // For normalized numbers (the following) it should be a simple matter
+  // of getting rid of the top implicit bit
+  EXPECT_EQ(bits_set({}), normalized_significand({0}, 0));
+  EXPECT_EQ(bits_set({0}), normalized_significand({0, 1}, 0));
+  EXPECT_EQ(bits_set({0, 1}), normalized_significand({0, 1, 2}, 0));
+  EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 0));
+  EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 32));
+  EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 126));
+
+  // For denormalized numbers we expect the normalized significand to
+  // shift as if it were normalized. This means, in practice that the
+  // top_most set bit will be cut off. Looks very similar to above (on purpose)
+  EXPECT_EQ(bits_set({}), normalized_significand({0}, -127));
+  EXPECT_EQ(bits_set({3}), normalized_significand({0, 4}, -128));
+  EXPECT_EQ(bits_set({3}), normalized_significand({0, 4}, -127));
+  EXPECT_EQ(bits_set({}), normalized_significand({22}, -127));
+  EXPECT_EQ(bits_set({0}), normalized_significand({21, 22}, -127));
+}
+
+// Returns the 32-bit floating point value created by
+// calling setFromSignUnbiasedExponentAndNormalizedSignificand
+// on a HexFloat<FloatProxy<float>>
+float set_from_sign(bool negative, int32_t unbiased_exponent,
+                   uint32_t significand, bool round_denorm_up) {
+  spvutils::HexFloat<spvutils::FloatProxy<float>>  f(0.f);
+  f.setFromSignUnbiasedExponentAndNormalizedSignificand(
+      negative, unbiased_exponent, significand, round_denorm_up);
+  return f.value().getAsFloat();
+}
+
+TEST(HexFloatOperationTests,
+     SetFromSignUnbiasedExponentAndNormalizedSignificand) {
+
+  EXPECT_EQ(1.f, set_from_sign(false, 0, 0, false));
+
+  // Tests insertion of various denormalized numbers with and without round up.
+  EXPECT_EQ(static_cast<float>(ldexp(1.f, -149)), set_from_sign(false, -149, 0, false));
+  EXPECT_EQ(static_cast<float>(ldexp(1.f, -149)), set_from_sign(false, -149, 0, true));
+  EXPECT_EQ(0.f, set_from_sign(false, -150, 1, false));
+  EXPECT_EQ(static_cast<float>(ldexp(1.f, -149)), set_from_sign(false, -150, 1, true));
+
+  EXPECT_EQ(ldexp(1.0f, -127), set_from_sign(false, -127, 0, false));
+  EXPECT_EQ(ldexp(1.0f, -128), set_from_sign(false, -128, 0, false));
+  EXPECT_EQ(float_fractions({0, 1, 2, 5}),
+            set_from_sign(false, 0, bits_set({0, 1, 4}), false));
+  EXPECT_EQ(ldexp(float_fractions({0, 1, 2, 5}), -32),
+            set_from_sign(false, -32, bits_set({0, 1, 4}), false));
+  EXPECT_EQ(ldexp(float_fractions({0, 1, 2, 5}), -128),
+            set_from_sign(false, -128, bits_set({0, 1, 4}), false));
+
+  // The negative cases from above.
+  EXPECT_EQ(-1.f, set_from_sign(true, 0, 0, false));
+  EXPECT_EQ(-ldexp(1.0, -127), set_from_sign(true, -127, 0, false));
+  EXPECT_EQ(-ldexp(1.0, -128), set_from_sign(true, -128, 0, false));
+  EXPECT_EQ(-float_fractions({0, 1, 2, 5}),
+            set_from_sign(true, 0, bits_set({0, 1, 4}), false));
+  EXPECT_EQ(-ldexp(float_fractions({0, 1, 2, 5}), -32),
+            set_from_sign(true, -32, bits_set({0, 1, 4}), false));
+  EXPECT_EQ(-ldexp(float_fractions({0, 1, 2, 5}), -128),
+            set_from_sign(true, -128, bits_set({0, 1, 4}), false));
+}
+
+TEST(HexFloatOperationTests, NonRounding) {
+  // Rounding from 32-bit hex-float to 32-bit hex-float should be trivial,
+  // except in the denorm case which is a bit more complex.
+  using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
+  bool carry_bit = false;
+
+  spvutils::round_direction rounding[] = {
+      spvutils::round_direction::kToZero,
+      spvutils::round_direction::kToNearestEven,
+      spvutils::round_direction::kToPositiveInfinity,
+      spvutils::round_direction::kToNegativeInfinity};
+
+  // Everything fits, so this should be straight-forward
+  for (spvutils::round_direction round : rounding) {
+    EXPECT_EQ(bits_set({}), HF(0.f).getRoundedNormalizedSignificand<HF>(
+                                round, &carry_bit));
+    EXPECT_FALSE(carry_bit);
+
+    EXPECT_EQ(bits_set({0}),
+              HF(float_fractions({0, 1}))
+                  .getRoundedNormalizedSignificand<HF>(round, &carry_bit));
+    EXPECT_FALSE(carry_bit);
+
+    EXPECT_EQ(bits_set({1, 3}),
+              HF(float_fractions({0, 2, 4}))
+                  .getRoundedNormalizedSignificand<HF>(round, &carry_bit));
+    EXPECT_FALSE(carry_bit);
+
+    EXPECT_EQ(
+        bits_set({0, 1, 4}),
+        HF(static_cast<float>(-ldexp(float_fractions({0, 1, 2, 5}), -128)))
+            .getRoundedNormalizedSignificand<HF>(round, &carry_bit));
+    EXPECT_FALSE(carry_bit);
+
+    EXPECT_EQ(
+        bits_set({0, 1, 4, 22}),
+        HF(static_cast<float>(float_fractions({0, 1, 2, 5, 23})))
+            .getRoundedNormalizedSignificand<HF>(round, &carry_bit));
+    EXPECT_FALSE(carry_bit);
+  }
+}
+
+using RD = spvutils::round_direction;
+struct RoundSignificandCase {
+  float source_float;
+  std::pair<int16_t, bool> expected_results;
+  spvutils::round_direction round;
+};
+
+using HexFloatRoundTest =
+    ::testing::TestWithParam<RoundSignificandCase>;
+
+TEST_P(HexFloatRoundTest, RoundDownToFP16) {
+  using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
+  using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
+
+  HF input_value(GetParam().source_float);
+  bool carry_bit = false;
+  EXPECT_EQ(GetParam().expected_results.first,
+            input_value.getRoundedNormalizedSignificand<HF16>(
+                GetParam().round, &carry_bit));
+  EXPECT_EQ(carry_bit, GetParam().expected_results.second);
+}
+
+// clang-format off
+INSTANTIATE_TEST_CASE_P(F32ToF16, HexFloatRoundTest,
+  ::testing::ValuesIn(std::vector<RoundSignificandCase>(
+  {
+    {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToZero},
+    {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToNearestEven},
+    {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToPositiveInfinity},
+    {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToNegativeInfinity},
+    {float_fractions({0, 1}), std::make_pair(half_bits_set({0}), false), RD::kToZero},
+
+    {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToZero},
+    {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity},
+    {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity},
+    {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToNearestEven},
+
+    {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToZero},
+    {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 8}), false), RD::kToPositiveInfinity},
+    {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNegativeInfinity},
+    {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 8}), false), RD::kToNearestEven},
+
+    {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToZero},
+    {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity},
+    {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity},
+    {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven},
+
+    {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToZero},
+    {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToPositiveInfinity},
+    {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNegativeInfinity},
+    {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven},
+
+    {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0}), false), RD::kToZero},
+    {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity},
+    {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity},
+    {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven},
+
+    // Carries
+    {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), false), RD::kToZero},
+    {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({}), true), RD::kToPositiveInfinity},
+    {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), false), RD::kToNegativeInfinity},
+    {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({}), true), RD::kToNearestEven},
+
+    // Cases where original number was denorm. Note: this should have no effect
+    // the number is pre-normalized.
+    {static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -128)), std::make_pair(half_bits_set({0}), false), RD::kToZero},
+    {static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -129)), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity},
+    {static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -131)), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity},
+    {static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -130)), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven},
+   })));
+// clang-format on
+
+struct UpCastSignificandCase {
+  uint16_t source_half;
+  uint32_t expected_result;
+};
+
+using HexFloatRoundUpSignificandTest =
+    ::testing::TestWithParam<UpCastSignificandCase>;
+TEST_P(HexFloatRoundUpSignificandTest, Widening) {
+  using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
+  using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
+  bool carry_bit = false;
+
+  spvutils::round_direction rounding[] = {
+      spvutils::round_direction::kToZero,
+      spvutils::round_direction::kToNearestEven,
+      spvutils::round_direction::kToPositiveInfinity,
+      spvutils::round_direction::kToNegativeInfinity};
+
+  // Everything fits, so everything should just be bit-shifts.
+  for (spvutils::round_direction round : rounding) {
+    carry_bit = false;
+    HF16 input_value(GetParam().source_half);
+    EXPECT_EQ(
+        GetParam().expected_result,
+        input_value.getRoundedNormalizedSignificand<HF>(round, &carry_bit))
+        << std::hex << "0x"
+        << input_value.getRoundedNormalizedSignificand<HF>(round, &carry_bit)
+        << "  0x" << GetParam().expected_result;
+    EXPECT_FALSE(carry_bit);
+  }
+}
+
+INSTANTIATE_TEST_CASE_P(F16toF32, HexFloatRoundUpSignificandTest,
+  // 0xFC00 of the source 16-bit hex value cover the sign and the exponent.
+  // They are ignored for this test.
+  ::testing::ValuesIn(std::vector<UpCastSignificandCase>(
+  {
+    {0x3F00, 0x600000},
+    {0x0F00, 0x600000},
+    {0x0F01, 0x602000},
+    {0x0FFF, 0x7FE000},
+  })));
+
+struct DownCastTest {
+  float source_float;
+  uint16_t expected_half;
+  std::vector<spvutils::round_direction> directions;
+};
+
+std::string get_round_text(spvutils::round_direction direction) {
+#define CASE(round_direction) \
+  case round_direction:      \
+    return #round_direction
+
+  switch (direction) {
+    CASE(spvutils::round_direction::kToZero);
+    CASE(spvutils::round_direction::kToPositiveInfinity);
+    CASE(spvutils::round_direction::kToNegativeInfinity);
+    CASE(spvutils::round_direction::kToNearestEven);
+  }
+#undef CASE
+  return "";
+}
+
+using HexFloatFP32To16Tests = ::testing::TestWithParam<DownCastTest>;
+
+TEST_P(HexFloatFP32To16Tests, NarrowingCasts) {
+  using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
+  using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
+  HF f(GetParam().source_float);
+  for (auto round : GetParam().directions) {
+    HF16 half(0);
+    f.castTo(half, round);
+    EXPECT_EQ(GetParam().expected_half, half.value().getAsFloat().get_value())
+        << get_round_text(round) << "  " << std::hex
+        << spvutils::BitwiseCast<uint32_t>(GetParam().source_float)
+        << " cast to: " << half.value().getAsFloat().get_value();
+  }
+}
+
+const uint16_t positive_infinity = 0x7C00;
+const uint16_t negative_infinity = 0xFC00;
+
+INSTANTIATE_TEST_CASE_P(F32ToF16, HexFloatFP32To16Tests,
+  ::testing::ValuesIn(std::vector<DownCastTest>(
+  {
+    // Exactly representable as half.
+    {0.f, 0x0, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {-0.f, 0x8000, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {1.0f, 0x3C00, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {-1.0f, 0xBC00, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+
+    {float_fractions({0, 1, 10}) , 0x3E01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {-float_fractions({0, 1, 10}) , 0xBE01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {static_cast<float>(ldexp(float_fractions({0, 1, 10}), 3)), 0x4A01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {static_cast<float>(-ldexp(float_fractions({0, 1, 10}), 3)), 0xCA01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+
+
+    // Underflow
+    {static_cast<float>(ldexp(1.0f, -25)), 0x0, {RD::kToZero, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {static_cast<float>(ldexp(1.0f, -25)), 0x1, {RD::kToPositiveInfinity}},
+    {static_cast<float>(-ldexp(1.0f, -25)), 0x8000, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNearestEven}},
+    {static_cast<float>(-ldexp(1.0f, -25)), 0x8001, {RD::kToNegativeInfinity}},
+    {static_cast<float>(ldexp(1.0f, -24)), 0x1, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+
+    // Overflow
+    {static_cast<float>(ldexp(1.0f, 16)), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {static_cast<float>(ldexp(1.0f, 18)), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {static_cast<float>(ldexp(1.3f, 16)), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {static_cast<float>(-ldexp(1.0f, 16)), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {static_cast<float>(-ldexp(1.0f, 18)), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {static_cast<float>(-ldexp(1.3f, 16)), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+
+    // Transfer of Infinities
+    {std::numeric_limits<float>::infinity(), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+    {-std::numeric_limits<float>::infinity(), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
+
+    // Nans are below because we cannot test for equality.
+  })));
+
+struct UpCastCase{
+  uint16_t source_half;
+  float expected_float;
+};
+
+using HexFloatFP16To32Tests = ::testing::TestWithParam<UpCastCase>;
+TEST_P(HexFloatFP16To32Tests, WideningCasts) {
+  using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
+  using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
+  HF16 f(GetParam().source_half);
+
+  spvutils::round_direction rounding[] = {
+      spvutils::round_direction::kToZero,
+      spvutils::round_direction::kToNearestEven,
+      spvutils::round_direction::kToPositiveInfinity,
+      spvutils::round_direction::kToNegativeInfinity};
+
+  // Everything fits, so everything should just be bit-shifts.
+  for (spvutils::round_direction round : rounding) {
+    HF flt(0.f);
+    f.castTo(flt, round);
+    EXPECT_EQ(GetParam().expected_float, flt.value().getAsFloat())
+        << get_round_text(round) << "  " << std::hex
+        << spvutils::BitwiseCast<uint16_t>(GetParam().source_half)
+        << " cast to: " << flt.value().getAsFloat();
+  }
+}
+
+INSTANTIATE_TEST_CASE_P(F16ToF32, HexFloatFP16To32Tests,
+  ::testing::ValuesIn(std::vector<UpCastCase>(
+  {
+    {0x0000, 0.f},
+    {0x8000, -0.f},
+    {0x3C00, 1.0f},
+    {0xBC00, -1.0f},
+    {0x3F00, float_fractions({0, 1, 2})},
+    {0xBF00, -float_fractions({0, 1, 2})},
+    {0x3F01, float_fractions({0, 1, 2, 10})},
+    {0xBF01, -float_fractions({0, 1, 2, 10})},
+
+    // denorm
+    {0x0001, static_cast<float>(ldexp(1.0, -24))},
+    {0x0002, static_cast<float>(ldexp(1.0, -23))},
+    {0x8001, static_cast<float>(-ldexp(1.0, -24))},
+    {0x8011, static_cast<float>(-ldexp(1.0, -20) + -ldexp(1.0, -24))},
+
+    // inf
+    {0x7C00, std::numeric_limits<float>::infinity()},
+    {0xFC00, -std::numeric_limits<float>::infinity()},
+  })));
+
+TEST(HexFloatOperationTests, NanTests) {
+  using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
+  using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
+  spvutils::round_direction rounding[] = {
+      spvutils::round_direction::kToZero,
+      spvutils::round_direction::kToNearestEven,
+      spvutils::round_direction::kToPositiveInfinity,
+      spvutils::round_direction::kToNegativeInfinity};
+
+  // Everything fits, so everything should just be bit-shifts.
+  for (spvutils::round_direction round : rounding) {
+    HF16 f16(0);
+    HF f(0.f);
+    HF(std::numeric_limits<float>::quiet_NaN()).castTo(f16, round);
+    EXPECT_TRUE(f16.value().isNan());
+    HF(std::numeric_limits<float>::signaling_NaN()).castTo(f16, round);
+    EXPECT_TRUE(f16.value().isNan());
+
+    HF16(0x7C01).castTo(f, round);
+    EXPECT_TRUE(f.value().isNan());
+    HF16(0x7C11).castTo(f, round);
+    EXPECT_TRUE(f.value().isNan());
+    HF16(0xFC01).castTo(f, round);
+    EXPECT_TRUE(f.value().isNan());
+    HF16(0x7C10).castTo(f, round);
+    EXPECT_TRUE(f.value().isNan());
+    HF16(0xFF00).castTo(f, round);
+    EXPECT_TRUE(f.value().isNan());
+  }
+}
+
 // TODO(awoloszyn): Add fp16 tests and HexFloatTraits.
 }