Replace at::Half non-vectorized conversions with implementations from FP16 (#14411)
authorChandler Zuo <chandlerzuo@fb.com>
Tue, 4 Dec 2018 22:23:22 +0000 (14:23 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 22:32:33 +0000 (14:32 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14411
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14579

Folded the fp16 codes into c10.

Reviewed By: ezyang

Differential Revision: D13206450

fbshipit-source-id: 472208dd230dc49d33935622ff3286b17eeb0894

c10/Half-inl.h
c10/Half.cpp
c10/Half.h
c10/core/bitcasts.h [new file with mode: 0644]
c10/test/util/Half_test.cpp [new file with mode: 0644]

index 3edaf5c..966d55f 100644 (file)
@@ -20,7 +20,7 @@ inline C10_HOST_DEVICE Half::Half(float value) {
 #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
   x = __half_as_short(__float2half(value));
 #else
-  x = detail::float2halfbits(value);
+  x = detail::fp16_ieee_from_fp32_value(value);
 #endif
 }
 
@@ -30,7 +30,7 @@ inline C10_HOST_DEVICE Half::operator float() const {
 #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
   return __half2float(*reinterpret_cast<const __half*>(&x));
 #else
-  return detail::halfbits2float(x);
+  return detail::fp16_ieee_to_fp32_value(x);
 #endif
 }
 
index 69cf942..f90d30b 100644 (file)
@@ -8,95 +8,6 @@ static_assert(
     std::is_standard_layout<Half>::value,
     "c10::Half must be standard layout.");
 
-namespace detail {
-
-// Host functions for converting between FP32 and FP16 formats
-
-float halfbits2float(unsigned short h) {
-  unsigned sign = ((h >> 15) & 1);
-  unsigned exponent = ((h >> 10) & 0x1f);
-  unsigned mantissa = ((h & 0x3ff) << 13);
-
-  if (exponent == 0x1f) { /* NaN or Inf */
-    mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
-    exponent = 0xff;
-  } else if (!exponent) { /* Denorm or Zero */
-    if (mantissa) {
-      unsigned int msb;
-      exponent = 0x71;
-      do {
-        msb = (mantissa & 0x400000);
-        mantissa <<= 1; /* normalize */
-        --exponent;
-      } while (!msb);
-      mantissa &= 0x7fffff; /* 1.mantissa is implicit */
-    }
-  } else {
-    exponent += 0x70;
-  }
-
-  unsigned result_bit = (sign << 31) | (exponent << 23) | mantissa;
-
-  // Reinterpret the result bit pattern as a float
-  float result_float;
-  std::memcpy(&result_float, &result_bit, sizeof(result_float));
-  return result_float;
-}
-
-unsigned short float2halfbits(float src) {
-  // Reinterpret the float as a bit pattern
-  unsigned x;
-  std::memcpy(&x, &src, sizeof(x));
-
-  unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
-  unsigned sign, exponent, mantissa;
-
-  // Get rid of +NaN/-NaN case first.
-  if (u > 0x7f800000) {
-    return 0x7fffU;
-  }
-
-  sign = ((x >> 16) & 0x8000);
-
-  // Get rid of +Inf/-Inf, +0/-0.
-  if (u > 0x477fefff) {
-    return sign | 0x7c00U;
-  }
-  if (u < 0x33000001) {
-    return (sign | 0x0000);
-  }
-
-  exponent = ((u >> 23) & 0xff);
-  mantissa = (u & 0x7fffff);
-
-  if (exponent > 0x70) {
-    shift = 13;
-    exponent -= 0x70;
-  } else {
-    shift = 0x7e - exponent;
-    exponent = 0;
-    mantissa |= 0x800000;
-  }
-  lsb = (1 << shift);
-  lsb_s1 = (lsb >> 1);
-  lsb_m1 = (lsb - 1);
-
-  // Round to nearest even.
-  remainder = (mantissa & lsb_m1);
-  mantissa >>= shift;
-  if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
-    ++mantissa;
-    if (!(mantissa & 0x3ff)) {
-      ++exponent;
-      mantissa = 0;
-    }
-  }
-
-  return (sign | (exponent << 10) | mantissa);
-}
-
-} // namespace detail
-
 std::ostream& operator<<(std::ostream& out, const Half& value) {
   out << (float)value;
   return out;
index de25486..9765898 100644 (file)
@@ -9,12 +9,24 @@
 /// If you are writing a compute bound kernel, you can use the CUDA half
 /// intrinsics directly on the Half type from device code.
 
+#include <c10/core/bitcasts.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/C++17.h>
 
+#if defined(__cplusplus) && (__cplusplus >= 201103L)
 #include <cmath>
-#include <complex>
 #include <cstdint>
+#elif !defined(__OPENCL_VERSION__)
+#include <math.h>
+#include <stdint.h>
+#endif
+
+#ifdef _MSC_VER
+#include <intrin.h>
+#endif
+
+#include <complex>
+#include <cstring>
 #include <iosfwd>
 #include <limits>
 #include <sstream>
@@ -34,8 +46,246 @@ namespace c10 {
 
 namespace detail {
 
-C10_API float halfbits2float(unsigned short bits);
-C10_API unsigned short float2halfbits(float value);
+  /*
+   * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to
+   * a 32-bit floating-point number in IEEE single-precision format, in bit representation.
+   *
+   * @note The implementation doesn't use any floating-point operations.
+   */
+  static inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
+       /*
+        * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word:
+        *      +---+-----+------------+-------------------+
+        *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
+        *      +---+-----+------------+-------------------+
+        * Bits  31  26-30    16-25            0-15
+        *
+        * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits.
+        */
+       const uint32_t w = (uint32_t) h << 16;
+       /*
+        * Extract the sign of the input number into the high bit of the 32-bit word:
+        *
+        *      +---+----------------------------------+
+        *      | S |0000000 00000000 00000000 00000000|
+        *      +---+----------------------------------+
+        * Bits  31                 0-31
+        */
+       const uint32_t sign = w & UINT32_C(0x80000000);
+       /*
+        * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word:
+        *
+        *      +---+-----+------------+-------------------+
+        *      | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
+        *      +---+-----+------------+-------------------+
+        * Bits  30  27-31     17-26            0-16
+        */
+       const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
+       /*
+        * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized.
+        * If the initial number is normalized, some of its high 6 bits (sign == 0 and 5-bit exponent) equals one.
+        * In this case renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note that if we shift
+        * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the
+        * biased exponent into 1, and making mantissa normalized (i.e. without leading 1).
+        */
+#ifdef _MSC_VER
+        unsigned long nonsign_bsr;
+        _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
+        uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
+#else
+        uint32_t renorm_shift = __builtin_clz(nonsign);
+#endif
+        renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
+        /*
+         * Iff half-precision number has exponent of 15, the addition overflows
+         * it into bit 31, and the subsequent shift turns the high 9 bits
+         * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number
+         * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise
+         */
+        const int32_t inf_nan_mask =
+            ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
+        /*
+         * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
+         * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
+         * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
+         * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h)
+         * 0x00000000 otherwise
+         */
+        const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
+        /*
+         * 1. Shift nonsign left by renorm_shift to normalize it (if the input
+         * was denormal)
+         * 2. Shift nonsign right by 3 so the exponent (5 bits originally)
+         * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high
+         * bits of the 23-bit mantissa of IEEE single-precision number.
+         * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the
+         * different in exponent bias (0x7F for single-precision number less 0xF
+         * for half-precision number).
+         * 4. Subtract renorm_shift from the exponent (starting at bit 23) to
+         * account for renormalization. As renorm_shift is less than 0x70, this
+         * can be combined with step 3.
+         * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
+         * input was NaN or infinity.
+         * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
+         * into zero if the input was zero.
+         * 7. Combine with the sign of the input number.
+         */
+        return sign |
+            ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |
+              inf_nan_mask) &
+             ~zero_mask);
+  }
+
+  /*
+   * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to
+   * a 32-bit floating-point number in IEEE single-precision format.
+   *
+   * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals)
+   * floating-point operations and bitcasts between integer and floating-point variables.
+   */
+  static inline float fp16_ieee_to_fp32_value(uint16_t h) {
+       /*
+        * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word:
+        *      +---+-----+------------+-------------------+
+        *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
+        *      +---+-----+------------+-------------------+
+        * Bits  31  26-30    16-25            0-15
+        *
+        * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits.
+        */
+       const uint32_t w = (uint32_t) h << 16;
+       /*
+        * Extract the sign of the input number into the high bit of the 32-bit word:
+        *
+        *      +---+----------------------------------+
+        *      | S |0000000 00000000 00000000 00000000|
+        *      +---+----------------------------------+
+        * Bits  31                 0-31
+        */
+       const uint32_t sign = w & UINT32_C(0x80000000);
+       /*
+        * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word:
+        *
+        *      +-----+------------+---------------------+
+        *      |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
+        *      +-----+------------+---------------------+
+        * Bits  27-31    17-26            0-16
+        */
+       const uint32_t two_w = w + w;
+
+       /*
+        * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent
+        * of a single-precision floating-point number:
+        *
+        *       S|Exponent |          Mantissa
+        *      +-+---+-----+------------+----------------+
+        *      |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
+        *      +-+---+-----+------------+----------------+
+        * Bits   | 23-31   |           0-22
+        *
+        * Next, there are some adjustments to the exponent:
+        * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision
+        *   formats (0x7F - 0xF = 0x70)
+        * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number.
+        *   Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent
+        *   of the single-precision output must be 0xFF (max possible value). We do this correction in two steps:
+        *   - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested
+        *     by the difference in the exponent bias (see above).
+        *   - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of
+        *     exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias.
+        *     The floating-point multiplication hardware would ensure than Inf and NaN would retain their value on at least
+        *     partially IEEE754-compliant implementations.
+        *
+        * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not
+        * operate on denormal inputs, and do not produce denormal results.
+        */
+       const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+    // const float exp_scale = 0x1.0p-112f;
+    uint32_t scale_bits = (uint32_t) 15 << 23;
+    float exp_scale_val;
+    std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
+    const float exp_scale = exp_scale_val;
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+       /*
+        * Convert denormalized half-precision inputs into single-precision results (always normalized).
+        * Zero inputs are also handled here.
+        *
+        * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits.
+        * First, we shift mantissa into bits 0-9 of the 32-bit word.
+        *
+        *                  zeros           |  mantissa
+        *      +---------------------------+------------+
+        *      |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
+        *      +---------------------------+------------+
+        * Bits             10-31                0-9
+        *
+        * Now, remember that denormalized half-precision numbers are represented as:
+        *    FP16 = mantissa * 2**(-24).
+        * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input
+        * and with an exponent which would scale the corresponding mantissa bits to 2**(-24).
+        * A normalized single-precision floating-point number is represented as:
+        *    FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127)
+        * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision
+        * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount.
+        *
+        * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number
+        * is zero, the constructed single-precision number has the value of
+        *    FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5
+        * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of
+        * the input half-precision number.
+        */
+       const uint32_t magic_mask = UINT32_C(126) << 23;
+       const float magic_bias = 0.5f;
+       const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+       /*
+        * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the
+        *   input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the
+        *   input is either a denormal number, or zero.
+        * - Combine the result of conversion of exponent and mantissa with the sign of the input number.
+        */
+       const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+       const uint32_t result = sign |
+               (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+       return fp32_from_bits(result);
+  }
+
+  /*
+   * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in
+   * IEEE half-precision format, in bit representation.
+   *
+   * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals)
+   * floating-point operations and bitcasts between integer and floating-point variables.
+   */
+  static inline uint16_t fp16_ieee_from_fp32_value(float f) {
+    // const float scale_to_inf = 0x1.0p+112f;
+    // const float scale_to_zero = 0x1.0p-110f;
+    uint32_t scale_to_inf_bits = (uint32_t) 239 << 23;
+    uint32_t scale_to_zero_bits = (uint32_t) 17 << 23;
+    float scale_to_inf_val, scale_to_zero_val;
+    std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
+    std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
+    const float scale_to_inf = scale_to_inf_val;
+    const float scale_to_zero = scale_to_zero_val;
+
+       float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+       const uint32_t w = fp32_to_bits(f);
+       const uint32_t shl1_w = w + w;
+       const uint32_t sign = w & UINT32_C(0x80000000);
+       uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+       if (bias < UINT32_C(0x71000000)) {
+               bias = UINT32_C(0x71000000);
+       }
+
+       base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+       const uint32_t bits = fp32_to_bits(base);
+       const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+       const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+       const uint32_t nonsign = exp_bits + mantissa_bits;
+       return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+  }
 
 } // namespace detail
 
diff --git a/c10/core/bitcasts.h b/c10/core/bitcasts.h
new file mode 100644 (file)
index 0000000..eb8b502
--- /dev/null
@@ -0,0 +1,45 @@
+#pragma once
+
+#if defined(__cplusplus) && (__cplusplus >= 201103L)
+#include <cstdint>
+#elif !defined(__OPENCL_VERSION__)
+#include <stdint.h>
+#endif
+
+namespace c10 {
+namespace detail {
+
+static inline float fp32_from_bits(uint32_t w) {
+#if defined(__OPENCL_VERSION__)
+  return as_float(w);
+#elif defined(__CUDA_ARCH__)
+  return __uint_as_float((unsigned int)w);
+#elif defined(__INTEL_COMPILER)
+  return _castu32_f32(w);
+#else
+  union {
+    uint32_t as_bits;
+    float as_value;
+  } fp32 = {w};
+  return fp32.as_value;
+#endif
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+#if defined(__OPENCL_VERSION__)
+  return as_uint(f);
+#elif defined(__CUDA_ARCH__)
+  return (uint32_t)__float_as_uint(f);
+#elif defined(__INTEL_COMPILER)
+  return _castf32_u32(f);
+#else
+  union {
+    float as_value;
+    uint32_t as_bits;
+  } fp32 = {f};
+  return fp32.as_bits;
+#endif
+}
+
+} // namespace detail
+} // namespace c10
diff --git a/c10/test/util/Half_test.cpp b/c10/test/util/Half_test.cpp
new file mode 100644 (file)
index 0000000..525d58d
--- /dev/null
@@ -0,0 +1,108 @@
+#include <vector>
+
+#include <c10/Half.h>
+#include <gtest/gtest.h>
+
+namespace {
+namespace half_legacy_impl {
+float halfbits2float(unsigned short h) {
+  unsigned sign = ((h >> 15) & 1);
+  unsigned exponent = ((h >> 10) & 0x1f);
+  unsigned mantissa = ((h & 0x3ff) << 13);
+
+  if (exponent == 0x1f) { /* NaN or Inf */
+    mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
+    exponent = 0xff;
+  } else if (!exponent) { /* Denorm or Zero */
+    if (mantissa) {
+      unsigned int msb;
+      exponent = 0x71;
+      do {
+        msb = (mantissa & 0x400000);
+        mantissa <<= 1; /* normalize */
+        --exponent;
+      } while (!msb);
+      mantissa &= 0x7fffff; /* 1.mantissa is implicit */
+    }
+  } else {
+    exponent += 0x70;
+  }
+
+  unsigned result_bit = (sign << 31) | (exponent << 23) | mantissa;
+
+  // Reinterpret the result bit pattern as a float
+  float result_float;
+  std::memcpy(&result_float, &result_bit, sizeof(result_float));
+  return result_float;
+};
+
+unsigned short float2halfbits(float src) {
+  // Reinterpret the float as a bit pattern
+  unsigned x;
+  std::memcpy(&x, &src, sizeof(x));
+
+  unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
+  unsigned sign, exponent, mantissa;
+
+  // Get rid of +NaN/-NaN case first.
+  if (u > 0x7f800000) {
+    return 0x7fffU;
+  }
+
+  sign = ((x >> 16) & 0x8000);
+
+  // Get rid of +Inf/-Inf, +0/-0.
+  if (u > 0x477fefff) {
+    return sign | 0x7c00U;
+  }
+  if (u < 0x33000001) {
+    return (sign | 0x0000);
+  }
+
+  exponent = ((u >> 23) & 0xff);
+  mantissa = (u & 0x7fffff);
+
+  if (exponent > 0x70) {
+    shift = 13;
+    exponent -= 0x70;
+  } else {
+    shift = 0x7e - exponent;
+    exponent = 0;
+    mantissa |= 0x800000;
+  }
+  lsb = (1 << shift);
+  lsb_s1 = (lsb >> 1);
+  lsb_m1 = (lsb - 1);
+
+  // Round to nearest even.
+  remainder = (mantissa & lsb_m1);
+  mantissa >>= shift;
+  if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
+    ++mantissa;
+    if (!(mantissa & 0x3ff)) {
+      ++exponent;
+      mantissa = 0;
+    }
+  }
+
+  return (sign | (exponent << 10) | mantissa);
+};
+} // namespace half_legacy_impl
+TEST(HalfDoubleConversionTest, Half2Double) {
+  std::vector<uint16_t> inputs = {
+      0,
+      0xfbff, // 1111 1011 1111 1111
+      (1 << 15 | 1),
+      0x7bff // 0111 1011 1111 1111
+  };
+  for (auto x : inputs) {
+    auto target = c10::detail::fp16_ieee_to_fp32_value(x);
+    EXPECT_EQ(half_legacy_impl::halfbits2float(x), target)
+        << "Test failed for uint16 to float " << x << "\n";
+    EXPECT_EQ(
+        half_legacy_impl::float2halfbits(target),
+        c10::detail::fp16_ieee_from_fp32_value(target))
+        << "Test failed for float to uint16" << target << "\n";
+  }
+}
+} // namespace