From: Chandler Zuo Date: Tue, 4 Dec 2018 22:23:22 +0000 (-0800) Subject: Replace at::Half non-vectorized conversions with implementations from FP16 (#14411) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2481 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5ed9dfad982a4e20a51fdf30d3620e76b87ecff4;p=platform%2Fupstream%2Fpytorch.git Replace at::Half non-vectorized conversions with implementations from FP16 (#14411) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14411 Pull Request resolved: https://github.com/pytorch/pytorch/pull/14579 Folded the fp16 codes into c10. Reviewed By: ezyang Differential Revision: D13206450 fbshipit-source-id: 472208dd230dc49d33935622ff3286b17eeb0894 --- diff --git a/c10/Half-inl.h b/c10/Half-inl.h index 3edaf5c..966d55f 100644 --- a/c10/Half-inl.h +++ b/c10/Half-inl.h @@ -20,7 +20,7 @@ inline C10_HOST_DEVICE Half::Half(float value) { #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) x = __half_as_short(__float2half(value)); #else - x = detail::float2halfbits(value); + x = detail::fp16_ieee_from_fp32_value(value); #endif } @@ -30,7 +30,7 @@ inline C10_HOST_DEVICE Half::operator float() const { #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) return __half2float(*reinterpret_cast(&x)); #else - return detail::halfbits2float(x); + return detail::fp16_ieee_to_fp32_value(x); #endif } diff --git a/c10/Half.cpp b/c10/Half.cpp index 69cf942..f90d30b 100644 --- a/c10/Half.cpp +++ b/c10/Half.cpp @@ -8,95 +8,6 @@ static_assert( std::is_standard_layout::value, "c10::Half must be standard layout."); -namespace detail { - -// Host functions for converting between FP32 and FP16 formats - -float halfbits2float(unsigned short h) { - unsigned sign = ((h >> 15) & 1); - unsigned exponent = ((h >> 10) & 0x1f); - unsigned mantissa = ((h & 0x3ff) << 13); - - if (exponent == 0x1f) { /* NaN or Inf */ - mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); - exponent = 0xff; - } else if (!exponent) { /* Denorm or Zero */ - if (mantissa) { - unsigned int msb; - exponent = 0x71; - do { - msb = (mantissa & 0x400000); - mantissa <<= 1; /* normalize */ - --exponent; - } while (!msb); - mantissa &= 0x7fffff; /* 1.mantissa is implicit */ - } - } else { - exponent += 0x70; - } - - unsigned result_bit = (sign << 31) | (exponent << 23) | mantissa; - - // Reinterpret the result bit pattern as a float - float result_float; - std::memcpy(&result_float, &result_bit, sizeof(result_float)); - return result_float; -} - -unsigned short float2halfbits(float src) { - // Reinterpret the float as a bit pattern - unsigned x; - std::memcpy(&x, &src, sizeof(x)); - - unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; - unsigned sign, exponent, mantissa; - - // Get rid of +NaN/-NaN case first. - if (u > 0x7f800000) { - return 0x7fffU; - } - - sign = ((x >> 16) & 0x8000); - - // Get rid of +Inf/-Inf, +0/-0. - if (u > 0x477fefff) { - return sign | 0x7c00U; - } - if (u < 0x33000001) { - return (sign | 0x0000); - } - - exponent = ((u >> 23) & 0xff); - mantissa = (u & 0x7fffff); - - if (exponent > 0x70) { - shift = 13; - exponent -= 0x70; - } else { - shift = 0x7e - exponent; - exponent = 0; - mantissa |= 0x800000; - } - lsb = (1 << shift); - lsb_s1 = (lsb >> 1); - lsb_m1 = (lsb - 1); - - // Round to nearest even. - remainder = (mantissa & lsb_m1); - mantissa >>= shift; - if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { - ++mantissa; - if (!(mantissa & 0x3ff)) { - ++exponent; - mantissa = 0; - } - } - - return (sign | (exponent << 10) | mantissa); -} - -} // namespace detail - std::ostream& operator<<(std::ostream& out, const Half& value) { out << (float)value; return out; diff --git a/c10/Half.h b/c10/Half.h index de25486..9765898 100644 --- a/c10/Half.h +++ b/c10/Half.h @@ -9,12 +9,24 @@ /// If you are writing a compute bound kernel, you can use the CUDA half /// intrinsics directly on the Half type from device code. +#include #include #include +#if defined(__cplusplus) && (__cplusplus >= 201103L) #include -#include #include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#ifdef _MSC_VER +#include +#endif + +#include +#include #include #include #include @@ -34,8 +46,246 @@ namespace c10 { namespace detail { -C10_API float halfbits2float(unsigned short bits); -C10_API unsigned short float2halfbits(float value); + /* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ + static inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized. + * If the initial number is normalized, some of its high 6 bits (sign == 0 and 5-bit exponent) equals one. + * In this case renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note that if we shift + * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the + * biased exponent into 1, and making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows + * it into bit 31, and the subsequent shift turns the high 9 bits + * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number + * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) + * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0xF + * for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x70, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | + ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + } + + /* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ + static inline float fp16_ieee_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number. + * Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent + * of the single-precision output must be 0xFF (max possible value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested + * by the difference in the exponent bias (see above). + * - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of + * exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure than Inf and NaN would retain their value on at least + * partially IEEE754-compliant implementations. + * + * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not + * operate on denormal inputs, and do not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0xE0) << 23; + // const float exp_scale = 0x1.0p-112f; + uint32_t scale_bits = (uint32_t) 15 << 23; + float exp_scale_val; + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); + const float exp_scale = exp_scale_val; + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). + * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount. + * + * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); + } + + /* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * IEEE half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ + static inline uint16_t fp16_ieee_from_fp32_value(float f) { + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + uint32_t scale_to_inf_bits = (uint32_t) 239 << 23; + uint32_t scale_to_zero_bits = (uint32_t) 17 << 23; + float scale_to_inf_val, scale_to_zero_val; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); + } } // namespace detail diff --git a/c10/core/bitcasts.h b/c10/core/bitcasts.h new file mode 100644 index 0000000..eb8b502 --- /dev/null +++ b/c10/core/bitcasts.h @@ -0,0 +1,45 @@ +#pragma once + +#if defined(__cplusplus) && (__cplusplus >= 201103L) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#endif + +namespace c10 { +namespace detail { + +static inline float fp32_from_bits(uint32_t w) { +#if defined(__OPENCL_VERSION__) + return as_float(w); +#elif defined(__CUDA_ARCH__) + return __uint_as_float((unsigned int)w); +#elif defined(__INTEL_COMPILER) + return _castu32_f32(w); +#else + union { + uint32_t as_bits; + float as_value; + } fp32 = {w}; + return fp32.as_value; +#endif +} + +static inline uint32_t fp32_to_bits(float f) { +#if defined(__OPENCL_VERSION__) + return as_uint(f); +#elif defined(__CUDA_ARCH__) + return (uint32_t)__float_as_uint(f); +#elif defined(__INTEL_COMPILER) + return _castf32_u32(f); +#else + union { + float as_value; + uint32_t as_bits; + } fp32 = {f}; + return fp32.as_bits; +#endif +} + +} // namespace detail +} // namespace c10 diff --git a/c10/test/util/Half_test.cpp b/c10/test/util/Half_test.cpp new file mode 100644 index 0000000..525d58d --- /dev/null +++ b/c10/test/util/Half_test.cpp @@ -0,0 +1,108 @@ +#include + +#include +#include + +namespace { +namespace half_legacy_impl { +float halfbits2float(unsigned short h) { + unsigned sign = ((h >> 15) & 1); + unsigned exponent = ((h >> 10) & 0x1f); + unsigned mantissa = ((h & 0x3ff) << 13); + + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ + } + } else { + exponent += 0x70; + } + + unsigned result_bit = (sign << 31) | (exponent << 23) | mantissa; + + // Reinterpret the result bit pattern as a float + float result_float; + std::memcpy(&result_float, &result_bit, sizeof(result_float)); + return result_float; +}; + +unsigned short float2halfbits(float src) { + // Reinterpret the float as a bit pattern + unsigned x; + std::memcpy(&x, &src, sizeof(x)); + + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + return 0x7fffU; + } + + sign = ((x >> 16) & 0x8000); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + return sign | 0x7c00U; + } + if (u < 0x33000001) { + return (sign | 0x0000); + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; + } + } + + return (sign | (exponent << 10) | mantissa); +}; +} // namespace half_legacy_impl +TEST(HalfDoubleConversionTest, Half2Double) { + std::vector inputs = { + 0, + 0xfbff, // 1111 1011 1111 1111 + (1 << 15 | 1), + 0x7bff // 0111 1011 1111 1111 + }; + for (auto x : inputs) { + auto target = c10::detail::fp16_ieee_to_fp32_value(x); + EXPECT_EQ(half_legacy_impl::halfbits2float(x), target) + << "Test failed for uint16 to float " << x << "\n"; + EXPECT_EQ( + half_legacy_impl::float2halfbits(target), + c10::detail::fp16_ieee_from_fp32_value(target)) + << "Test failed for float to uint16" << target << "\n"; + } +} +} // namespace